diff --git a/.gitignore b/.gitignore index e4c44d0590d59..19db7ac277944 100644 --- a/.gitignore +++ b/.gitignore @@ -77,6 +77,7 @@ target/ unit-tests.log work/ docs/.jekyll-metadata +*.crc # For Hive TempStatsStore/ diff --git a/LICENSE b/LICENSE index c2b0d72663b55..b771bd552b762 100644 --- a/LICENSE +++ b/LICENSE @@ -201,102 +201,61 @@ limitations under the License. -======================================================================= -Apache Spark Subcomponents: - -The Apache Spark project contains subcomponents with separate copyright -notices and license terms. Your use of the source code for the these -subcomponents is subject to the terms and conditions of the following -licenses. - - -======================================================================== -For heapq (pyspark/heapq3.py): -======================================================================== - -See license/LICENSE-heapq.txt - -======================================================================== -For SnapTree: -======================================================================== - -See license/LICENSE-SnapTree.txt - -======================================================================== -For jbcrypt: -======================================================================== - -See license/LICENSE-jbcrypt.txt - -======================================================================== -BSD-style licenses -======================================================================== - -The following components are provided under a BSD-style license. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) - (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) - (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) - (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) - (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) - (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:0.9.94 - http://jline.sourceforge.net) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) - (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) - (Interpreter classes (all .scala files in repl/src/main/scala - except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), - and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) - (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) - (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) - (BSD-style) spire-macros 
(org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) - (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) - (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) - (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) - (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) - (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) - (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) - (BSD licence) sbt and sbt-launch-lib.bash - (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) - (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) - (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) - (BSD 2 Clause) Zstd-jni (https://github.com/luben/zstd-jni/blob/master/LICENSE) - (BSD license) Zstd (https://github.com/facebook/zstd/blob/v1.3.1/LICENSE) - -======================================================================== -MIT licenses -======================================================================== - -The following components are provided under the MIT License. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. 
- - (MIT License) JCL 1.1.1 implemented over SLF4J (org.slf4j:jcl-over-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) JUL to SLF4J bridge (org.slf4j:jul-to-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J API Module (org.slf4j:slf4j-api:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) - (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) - (MIT License) scopt (com.github.scopt:scopt_2.11:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) - (MIT License) jquery (https://jquery.org/license/) - (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) - (MIT License) graphlib-dot (https://github.com/cpettitt/graphlib-dot) - (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) - (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) - (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) - (MIT License) datatables (http://datatables.net/license) - (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) - (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) - (MIT License) blockUI (http://jquery.malsup.com/block/) - (MIT License) RowsGroup (http://datatables.net/license/mit) - (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) - (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) - (MIT License) machinist (https://github.com/typelevel/machinist) +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache Software Foundation License 2.0 +-------------------------------------- + +common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +BSD 3-Clause +------------ + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. 
+
+
+MIT License
+-----------
+
+core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
+core/src/main/resources/org/apache/spark/ui/static/*dataTables*
+core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js
+core/src/main/resources/org/apache/spark/ui/static/jquery*
+core/src/main/resources/org/apache/spark/ui/static/sorttable.js
+docs/js/vendor/anchor.min.js
+docs/js/vendor/jquery*
+docs/js/vendor/modernizr*
+
+
+Creative Commons CC0 1.0 Universal Public Domain Dedication
+-----------------------------------------------------------
+(see LICENSE-CC0.txt)
+
+data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg
+data/mllib/images/kittens/54893.jpg
+data/mllib/images/kittens/DP153539.jpg
+data/mllib/images/kittens/DP802813.jpg
+data/mllib/images/multi-channel/chr30.4.184.jpg
\ No newline at end of file
diff --git a/LICENSE-binary b/LICENSE-binary
new file mode 100644
index 0000000000000..b94ea90de08be
--- /dev/null
+++ b/LICENSE-binary
@@ -0,0 +1,518 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner.
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ +This project bundles some components that are also licensed under the Apache +License Version 2.0: + +commons-beanutils:commons-beanutils +org.apache.zookeeper:zookeeper +oro:oro +commons-configuration:commons-configuration +commons-digester:commons-digester +com.chuusai:shapeless_2.11 +com.googlecode.javaewah:JavaEWAH +com.twitter:chill-java +com.twitter:chill_2.11 +com.univocity:univocity-parsers +javax.jdo:jdo-api +joda-time:joda-time +net.sf.opencsv:opencsv +org.apache.derby:derby +org.objenesis:objenesis +org.roaringbitmap:RoaringBitmap +org.scalanlp:breeze-macros_2.11 +org.scalanlp:breeze_2.11 +org.typelevel:macro-compat_2.11 +org.yaml:snakeyaml +org.apache.xbean:xbean-asm5-shaded +com.squareup.okhttp3:logging-interceptor +com.squareup.okhttp3:okhttp +com.squareup.okio:okio +org.apache.spark:spark-catalyst_2.11 +org.apache.spark:spark-kvstore_2.11 +org.apache.spark:spark-launcher_2.11 +org.apache.spark:spark-mllib-local_2.11 +org.apache.spark:spark-network-common_2.11 +org.apache.spark:spark-network-shuffle_2.11 +org.apache.spark:spark-sketch_2.11 +org.apache.spark:spark-tags_2.11 +org.apache.spark:spark-unsafe_2.11 +commons-httpclient:commons-httpclient +com.vlkan:flatbuffers +com.ning:compress-lzf +io.airlift:aircompressor +io.dropwizard.metrics:metrics-core +io.dropwizard.metrics:metrics-ganglia +io.dropwizard.metrics:metrics-graphite +io.dropwizard.metrics:metrics-json +io.dropwizard.metrics:metrics-jvm +org.iq80.snappy:snappy +com.clearspring.analytics:stream +com.jamesmurty.utils:java-xmlbuilder +commons-codec:commons-codec +commons-collections:commons-collections +io.fabric8:kubernetes-client +io.fabric8:kubernetes-model +io.netty:netty +io.netty:netty-all +net.hydromatic:eigenbase-properties +net.sf.supercsv:super-csv +org.apache.arrow:arrow-format +org.apache.arrow:arrow-memory +org.apache.arrow:arrow-vector +org.apache.calcite:calcite-avatica +org.apache.calcite:calcite-core +org.apache.calcite:calcite-linq4j +org.apache.commons:commons-crypto +org.apache.commons:commons-lang3 +org.apache.hadoop:hadoop-annotations +org.apache.hadoop:hadoop-auth +org.apache.hadoop:hadoop-client +org.apache.hadoop:hadoop-common +org.apache.hadoop:hadoop-hdfs +org.apache.hadoop:hadoop-mapreduce-client-app +org.apache.hadoop:hadoop-mapreduce-client-common 
+org.apache.hadoop:hadoop-mapreduce-client-core +org.apache.hadoop:hadoop-mapreduce-client-jobclient +org.apache.hadoop:hadoop-mapreduce-client-shuffle +org.apache.hadoop:hadoop-yarn-api +org.apache.hadoop:hadoop-yarn-client +org.apache.hadoop:hadoop-yarn-common +org.apache.hadoop:hadoop-yarn-server-common +org.apache.hadoop:hadoop-yarn-server-web-proxy +org.apache.httpcomponents:httpclient +org.apache.httpcomponents:httpcore +org.apache.orc:orc-core +org.apache.orc:orc-mapreduce +org.mortbay.jetty:jetty +org.mortbay.jetty:jetty-util +com.jolbox:bonecp +org.json4s:json4s-ast_2.11 +org.json4s:json4s-core_2.11 +org.json4s:json4s-jackson_2.11 +org.json4s:json4s-scalap_2.11 +com.carrotsearch:hppc +com.fasterxml.jackson.core:jackson-annotations +com.fasterxml.jackson.core:jackson-core +com.fasterxml.jackson.core:jackson-databind +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml +com.fasterxml.jackson.module:jackson-module-jaxb-annotations +com.fasterxml.jackson.module:jackson-module-paranamer +com.fasterxml.jackson.module:jackson-module-scala_2.11 +com.github.mifmif:generex +com.google.code.findbugs:jsr305 +com.google.code.gson:gson +com.google.inject:guice +com.google.inject.extensions:guice-servlet +com.twitter:parquet-hadoop-bundle +commons-beanutils:commons-beanutils-core +commons-cli:commons-cli +commons-dbcp:commons-dbcp +commons-io:commons-io +commons-lang:commons-lang +commons-logging:commons-logging +commons-net:commons-net +commons-pool:commons-pool +io.fabric8:zjsonpatch +javax.inject:javax.inject +javax.validation:validation-api +log4j:apache-log4j-extras +log4j:log4j +net.sf.jpam:jpam +org.apache.avro:avro +org.apache.avro:avro-ipc +org.apache.avro:avro-mapred +org.apache.commons:commons-compress +org.apache.commons:commons-math3 +org.apache.curator:curator-client +org.apache.curator:curator-framework +org.apache.curator:curator-recipes +org.apache.directory.api:api-asn1-api +org.apache.directory.api:api-util +org.apache.directory.server:apacheds-i18n +org.apache.directory.server:apacheds-kerberos-codec +org.apache.htrace:htrace-core +org.apache.ivy:ivy +org.apache.mesos:mesos +org.apache.parquet:parquet-column +org.apache.parquet:parquet-common +org.apache.parquet:parquet-encoding +org.apache.parquet:parquet-format +org.apache.parquet:parquet-hadoop +org.apache.parquet:parquet-jackson +org.apache.thrift:libfb303 +org.apache.thrift:libthrift +org.codehaus.jackson:jackson-core-asl +org.codehaus.jackson:jackson-mapper-asl +org.datanucleus:datanucleus-api-jdo +org.datanucleus:datanucleus-core +org.datanucleus:datanucleus-rdbms +org.lz4:lz4-java +org.spark-project.hive:hive-beeline +org.spark-project.hive:hive-cli +org.spark-project.hive:hive-exec +org.spark-project.hive:hive-jdbc +org.spark-project.hive:hive-metastore +org.xerial.snappy:snappy-java +stax:stax-api +xerces:xercesImpl +org.codehaus.jackson:jackson-jaxrs +org.codehaus.jackson:jackson-xc +org.eclipse.jetty:jetty-client +org.eclipse.jetty:jetty-continuation +org.eclipse.jetty:jetty-http +org.eclipse.jetty:jetty-io +org.eclipse.jetty:jetty-jndi +org.eclipse.jetty:jetty-plus +org.eclipse.jetty:jetty-proxy +org.eclipse.jetty:jetty-security +org.eclipse.jetty:jetty-server +org.eclipse.jetty:jetty-servlet +org.eclipse.jetty:jetty-servlets +org.eclipse.jetty:jetty-util +org.eclipse.jetty:jetty-webapp +org.eclipse.jetty:jetty-xml + +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* 
+core/src/main/resources/org/apache/spark/ui/static/jsonFormatter*
+core/src/main/resources/org/apache/spark/ui/static/vis*
+docs/js/vendor/bootstrap.js
+
+
+------------------------------------------------------------------------------------
+This product bundles various third-party components under other open source licenses.
+This section summarizes those components and their licenses. See licenses-binary/
+for text of these licenses.
+
+
+BSD 2-Clause
+------------
+
+com.github.luben:zstd-jni
+javolution:javolution
+com.esotericsoftware:kryo-shaded
+com.esotericsoftware:minlog
+com.esotericsoftware:reflectasm
+com.google.protobuf:protobuf-java
+org.codehaus.janino:commons-compiler
+org.codehaus.janino:janino
+jline:jline
+org.jodd:jodd-core
+
+
+BSD 3-Clause
+------------
+
+dk.brics.automaton:automaton
+org.antlr:antlr-runtime
+org.antlr:ST4
+org.antlr:stringtemplate
+org.antlr:antlr4-runtime
+antlr:antlr
+com.github.fommil.netlib:core
+com.thoughtworks.paranamer:paranamer
+org.scala-lang:scala-compiler
+org.scala-lang:scala-library
+org.scala-lang:scala-reflect
+org.scala-lang.modules:scala-parser-combinators_2.11
+org.scala-lang.modules:scala-xml_2.11
+org.fusesource.leveldbjni:leveldbjni-all
+net.sourceforge.f2j:arpack_combined_all
+xmlenc:xmlenc
+net.sf.py4j:py4j
+org.jpmml:pmml-model
+org.jpmml:pmml-schema
+
+python/lib/py4j-*-src.zip
+python/pyspark/cloudpickle.py
+python/pyspark/join.py
+core/src/main/resources/org/apache/spark/ui/static/d3.min.js
+
+The CSS style for the navigation sidebar of the documentation was originally
+submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project
+is distributed under the 3-Clause BSD license.
+
+
+MIT License
+-----------
+
+org.spire-math:spire-macros_2.11
+org.spire-math:spire_2.11
+org.typelevel:machinist_2.11
+net.razorvine:pyrolite
+org.slf4j:jcl-over-slf4j
+org.slf4j:jul-to-slf4j
+org.slf4j:slf4j-api
+org.slf4j:slf4j-log4j12
+com.github.scopt:scopt_2.11
+
+core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
+core/src/main/resources/org/apache/spark/ui/static/*dataTables*
+core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js
+core/src/main/resources/org/apache/spark/ui/static/jquery*
+core/src/main/resources/org/apache/spark/ui/static/sorttable.js
+docs/js/vendor/anchor.min.js
+docs/js/vendor/jquery*
+docs/js/vendor/modernizr*
+
+
+Common Development and Distribution License (CDDL) 1.0
+------------------------------------------------------
+
+javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html
+javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173
+
+
+Common Development and Distribution License (CDDL) 1.1
+------------------------------------------------------
+
+javax.annotation:javax.annotation-api https://jcp.org/en/jsr/detail?id=250
+javax.servlet:javax.servlet-api https://javaee.github.io/servlet-spec/
+javax.transaction:jta http://www.oracle.com/technetwork/java/index.html
+javax.ws.rs:javax.ws.rs-api https://github.com/jax-rs
+javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2
+org.glassfish.hk2:hk2-api https://github.com/javaee/glassfish
+org.glassfish.hk2:hk2-locator (same)
+org.glassfish.hk2:hk2-utils
+org.glassfish.hk2:osgi-resource-locator
+org.glassfish.hk2.external:aopalliance-repackaged
+org.glassfish.hk2.external:javax.inject
+org.glassfish.jersey.bundles.repackaged:jersey-guava
+org.glassfish.jersey.containers:jersey-container-servlet
+org.glassfish.jersey.containers:jersey-container-servlet-core +org.glassfish.jersey.core:jersey-client +org.glassfish.jersey.core:jersey-common +org.glassfish.jersey.core:jersey-server +org.glassfish.jersey.media:jersey-media-jaxb + + +Mozilla Public License (MPL) 1.1 +-------------------------------- + +com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/ + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +Public Domain +------------- + +aopalliance:aopalliance +net.iharder:base64 +org.tukaani:xz + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg diff --git a/NOTICE b/NOTICE index 6ec240efbf12e..fefe08b38afc5 100644 --- a/NOTICE +++ b/NOTICE @@ -5,663 +5,24 @@ This product includes software developed at The Apache Software Foundation (http://www.apache.org/). -======================================================================== -Common Development and Distribution License 1.0 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.0. See project link for details. - - (CDDL 1.0) Glassfish Jasper (org.mortbay.jetty:jsp-2.1:6.1.14 - http://jetty.mortbay.org/project/modules/jsp-2.1) - (CDDL 1.0) JAX-RS (https://jax-rs-spec.java.net/) - (CDDL 1.0) Servlet Specification 2.5 API (org.mortbay.jetty:servlet-api-2.5:6.1.14 - http://jetty.mortbay.org/project/modules/servlet-api-2.5) - (CDDL 1.0) (GPL2 w/ CPE) javax.annotation API (https://glassfish.java.net/nonav/public/CDDL+GPL.html) - (COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.0) (GNU General Public Library) Streaming API for XML (javax.xml.stream:stax-api:1.0-2 - no url defined) - (Common Development and Distribution License (CDDL) v1.0) JavaBeans Activation Framework (JAF) (javax.activation:activation:1.1 - http://java.sun.com/products/javabeans/jaf/index.jsp) - -======================================================================== -Common Development and Distribution License 1.1 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.1. See project link for details. - - (CDDL 1.1) (GPL2 w/ CPE) org.glassfish.hk2 (https://hk2.java.net) - (CDDL 1.1) (GPL2 w/ CPE) JAXB API bundle for GlassFish V3 (javax.xml.bind:jaxb-api:2.2.2 - https://jaxb.dev.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) JAXB RI (com.sun.xml.bind:jaxb-impl:2.2.3-1 - http://jaxb.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) Jersey 2 (https://jersey.java.net) - -======================================================================== -Common Public License 1.0 -======================================================================== - -The following components are provided under the Common Public 1.0 License. See project link for details. 
- - (Common Public License Version 1.0) JUnit (junit:junit-dep:4.10 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:3.8.1 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:4.8.2 - http://junit.org) - -======================================================================== -Eclipse Public License 1.0 -======================================================================== - -The following components are provided under the Eclipse Public License 1.0. See project link for details. - - (Eclipse Public License v1.0) Eclipse JDT Core (org.eclipse.jdt:core:3.1.1 - http://www.eclipse.org/jdt/) - -======================================================================== -Mozilla Public License 1.0 -======================================================================== - -The following components are provided under the Mozilla Public License 1.0. See project link for details. - - (GPL) (LGPL) (MPL) JTransforms (com.github.rwl:jtransforms:2.4.0 - http://sourceforge.net/projects/jtransforms/) - (Mozilla Public License Version 1.1) jamon-runtime (org.jamon:jamon-runtime:2.3.1 - http://www.jamon.org/jamon-runtime/) - - - -======================================================================== -NOTICE files -======================================================================== - -The following NOTICEs are pertain to software distributed with this project. - - -// ------------------------------------------------------------------ -// NOTICE file corresponding to the section 4d of The Apache License, -// Version 2.0, in this case for -// ------------------------------------------------------------------ - -Apache Avro -Copyright 2009-2013 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -Apache Commons Codec -Copyright 2002-2009 The Apache Software Foundation - -This product includes software developed by -The Apache Software Foundation (http://www.apache.org/). - --------------------------------------------------------------------------------- -src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java contains -test data from http://aspell.sourceforge.net/test/batch0.tab. - -Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org). Verbatim copying -and distribution of this entire article is permitted in any medium, -provided this notice is preserved. --------------------------------------------------------------------------------- - -Apache HttpComponents HttpClient -Copyright 1999-2011 The Apache Software Foundation - -This project contains annotations derived from JCIP-ANNOTATIONS -Copyright (c) 2005 Brian Goetz and Tim Peierls. 
See http://www.jcip.net - -Apache HttpComponents HttpCore -Copyright 2005-2011 The Apache Software Foundation - -Curator Recipes -Copyright 2011-2014 The Apache Software Foundation - -Curator Framework -Copyright 2011-2014 The Apache Software Foundation - -Curator Client -Copyright 2011-2014 The Apache Software Foundation - -Apache Geronimo -Copyright 2003-2008 The Apache Software Foundation - -Activation 1.1 -Copyright 2003-2007 The Apache Software Foundation - -Apache Commons Lang -Copyright 2001-2014 The Apache Software Foundation - -This product includes software from the Spring Framework, -under the Apache License 2.0 (see: StringUtils.containsWhitespace()) - -Apache log4j -Copyright 2007 The Apache Software Foundation - -# Compress LZF - -This library contains efficient implementation of LZF compression format, -as well as additional helper classes that build on JDK-provided gzip (deflat) -codec. - -## Licensing - -Library is licensed under Apache License 2.0, as per accompanying LICENSE file. - -## Credit - -Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). -It was started at Ning, inc., as an official Open Source process used by -platform backend, but after initial versions has been developed outside of -Ning by supporting community. - -Other contributors include: - -* Jon Hartlaub (first versions of streaming reader/writer; unit tests) -* Cedrik Lime: parallel LZF implementation - -Various community members have contributed bug reports, and suggested minor -fixes; these can be found from file "VERSION.txt" in SCM. - -Objenesis -Copyright 2006-2009 Joe Walnes, Henri Tremblay, Leonardo Mesquita - -Apache Commons Net -Copyright 2001-2010 The Apache Software Foundation - - The Netty Project - ================= - -Please visit the Netty web site for more information: - - * http://netty.io/ - -Copyright 2011 The Netty Project - -The Netty Project licenses this file to you under the Apache License, -version 2.0 (the "License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at: - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -License for the specific language governing permissions and limitations -under the License. - -Also, please refer to each LICENSE..txt file, which is located in -the 'license' directory of the distribution file, for the license terms of the -components that this product depends on. - -------------------------------------------------------------------------------- -This product contains the extensions to Java Collections Framework which has -been derived from the works by JSR-166 EG, Doug Lea, and Jason T. 
Greene: - - * LICENSE: - * license/LICENSE.jsr166y.txt (Public Domain) - * HOMEPAGE: - * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ - * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ - -This product contains a modified version of Robert Harder's Public Domain -Base64 Encoder and Decoder, which can be obtained at: - - * LICENSE: - * license/LICENSE.base64.txt (Public Domain) - * HOMEPAGE: - * http://iharder.sourceforge.net/current/java/base64/ - -This product contains a modified version of 'JZlib', a re-implementation of -zlib in pure Java, which can be obtained at: - - * LICENSE: - * license/LICENSE.jzlib.txt (BSD Style License) - * HOMEPAGE: - * http://www.jcraft.com/jzlib/ - -This product optionally depends on 'Protocol Buffers', Google's data -interchange format, which can be obtained at: - - * LICENSE: - * license/LICENSE.protobuf.txt (New BSD License) - * HOMEPAGE: - * http://code.google.com/p/protobuf/ - -This product optionally depends on 'SLF4J', a simple logging facade for Java, -which can be obtained at: - - * LICENSE: - * license/LICENSE.slf4j.txt (MIT License) - * HOMEPAGE: - * http://www.slf4j.org/ - -This product optionally depends on 'Apache Commons Logging', a logging -framework, which can be obtained at: - - * LICENSE: - * license/LICENSE.commons-logging.txt (Apache License 2.0) - * HOMEPAGE: - * http://commons.apache.org/logging/ - -This product optionally depends on 'Apache Log4J', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.log4j.txt (Apache License 2.0) - * HOMEPAGE: - * http://logging.apache.org/log4j/ - -This product optionally depends on 'JBoss Logging', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) - * HOMEPAGE: - * http://anonsvn.jboss.org/repos/common/common-logging-spi/ - -This product optionally depends on 'Apache Felix', an open source OSGi -framework implementation, which can be obtained at: - - * LICENSE: - * license/LICENSE.felix.txt (Apache License 2.0) - * HOMEPAGE: - * http://felix.apache.org/ - -This product optionally depends on 'Webbit', a Java event based -WebSocket and HTTP server: - - * LICENSE: - * license/LICENSE.webbit.txt (BSD License) - * HOMEPAGE: - * https://github.com/joewalnes/webbit - -# Jackson JSON processor - -Jackson is a high-performance, Free/Open Source JSON processing library. -It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has -been in development since 2007. -It is currently developed by a community of developers, as well as supported -commercially by FasterXML.com. - -Jackson core and extension components may be licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -## Credits - -A list of contributors may be found from CREDITS file, which is included -in some artifacts (usually source distributions); but is always available -from the source code management (SCM) system project uses. - -Jackson core and extension components may licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). 
- -mesos -Copyright 2014 The Apache Software Foundation - -Apache Thrift -Copyright 2006-2010 The Apache Software Foundation. - - Apache Ant - Copyright 1999-2013 The Apache Software Foundation - - The task is based on code Copyright (c) 2002, Landmark - Graphics Corp that has been kindly donated to the Apache Software - Foundation. - -Apache Commons IO -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons Math -Copyright 2001-2013 The Apache Software Foundation - -=============================================================================== - -The inverse error function implementation in the Erf class is based on CUDA -code developed by Mike Giles, Oxford-Man Institute of Quantitative Finance, -and published in GPU Computing Gems, volume 2, 2010. -=============================================================================== - -The BracketFinder (package org.apache.commons.math3.optimization.univariate) -and PowellOptimizer (package org.apache.commons.math3.optimization.general) -classes are based on the Python code in module "optimize.py" (version 0.5) -developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/) -Copyright © 2003-2009 SciPy Developers. -=============================================================================== - -The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, -RelationShip, SimplexSolver and SimplexTableau classes in package -org.apache.commons.math3.optimization.linear include software developed by -Benjamin McCann (http://www.benmccann.com) and distributed with -the following copyright: Copyright 2009 Google Inc. -=============================================================================== - -This product includes software developed by the -University of Chicago, as Operator of Argonne National -Laboratory. -The LevenbergMarquardtOptimizer class in package -org.apache.commons.math3.optimization.general includes software -translated from the lmder, lmpar and qrsolv Fortran routines -from the Minpack package -Minpack Copyright Notice (1999) University of Chicago. All rights reserved -=============================================================================== - -The GraggBulirschStoerIntegrator class in package -org.apache.commons.math3.ode.nonstiff includes software translated -from the odex Fortran routine developed by E. Hairer and G. Wanner. -Original source copyright: -Copyright (c) 2004, Ernst Hairer -=============================================================================== - -The EigenDecompositionImpl class in package -org.apache.commons.math3.linear includes software translated -from some LAPACK Fortran routines. Original source copyright: -Copyright (c) 1992-2008 The University of Tennessee. All rights reserved. -=============================================================================== - -The MersenneTwister class in package org.apache.commons.math3.random -includes software translated from the 2002-01-26 version of -the Mersenne-Twister generator written in C by Makoto Matsumoto and Takuji -Nishimura. Original source copyright: -Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, -All rights reserved -=============================================================================== - -The LocalizedFormatsTest class in the unit tests is an adapted version of -the OrekitMessagesTest class from the orekit library distributed under the -terms of the Apache 2 licence. 
Original source copyright: -Copyright 2010 CS Systèmes d'Information -=============================================================================== - -The HermiteInterpolator class and its corresponding test have been imported from -the orekit library distributed under the terms of the Apache 2 licence. Original -source copyright: -Copyright 2010-2012 CS Systèmes d'Information -=============================================================================== - -The creation of the package "o.a.c.m.analysis.integration.gauss" was inspired -by an original code donated by Sébastien Brisard. -=============================================================================== - -The complete text of licenses and disclaimers associated with the the original -sources enumerated above at the time of code translation are in the LICENSE.txt -file. - -This product currently only contains code developed by authors -of specific components, as identified by the source code files; -if such notes are missing files have been created by -Tatu Saloranta. - -For additional credits (generally to people who reported problems) -see CREDITS file. - -Apache Commons Lang -Copyright 2001-2011 The Apache Software Foundation - -Apache Commons Compress -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons CLI -Copyright 2001-2009 The Apache Software Foundation - -Google Guice - Extensions - Servlet -Copyright 2006-2011 Google, Inc. - -Google Guice - Core Library -Copyright 2006-2011 Google, Inc. - -Apache Jakarta HttpClient -Copyright 1999-2007 The Apache Software Foundation - -Apache Hive -Copyright 2008-2013 The Apache Software Foundation - -This product includes software developed by The Apache Software -Foundation (http://www.apache.org/). - -This product includes software developed by The JDBM Project -(http://jdbm.sourceforge.net/). - -This product includes/uses ANTLR (http://www.antlr.org/), -Copyright (c) 2003-2011, Terrence Parr. - -This product includes/uses StringTemplate (http://www.stringtemplate.org/), -Copyright (c) 2011, Terrence Parr. - -This product includes/uses ASM (http://asm.ow2.org/), -Copyright (c) 2000-2007 INRIA, France Telecom. - -This product includes/uses JLine (http://jline.sourceforge.net/), -Copyright (c) 2002-2006, Marc Prud'hommeaux . - -This product includes/uses SQLLine (http://sqlline.sourceforge.net), -Copyright (c) 2002, 2003, 2004, 2005 Marc Prud'hommeaux . - -This product includes/uses SLF4J (http://www.slf4j.org/), -Copyright (c) 2004-2010 QOS.ch - -This product includes/uses Bootstrap (http://twitter.github.com/bootstrap/), -Copyright (c) 2012 Twitter, Inc. - -This product includes/uses Glyphicons (http://glyphicons.com/), -Copyright (c) 2010 - 2012 Jan Kovarík - -This product includes DataNucleus (http://www.datanucleus.org/) -Copyright 2008-2008 DataNucleus - -This product includes Guava (http://code.google.com/p/guava-libraries/) -Copyright (C) 2006 Google Inc. - -This product includes JavaEWAH (http://code.google.com/p/javaewah/) -Copyright (C) 2011 Google Inc. - -Apache Commons Pool -Copyright 1999-2009 The Apache Software Foundation - -This product includes/uses Kubernetes & OpenShift 3 Java Client (https://github.com/fabric8io/kubernetes-client) -Copyright (C) 2015 Red Hat, Inc. 
- -This product includes/uses OkHttp (https://github.com/square/okhttp) -Copyright (C) 2012 The Android Open Source Project - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the DataNucleus distribution. == -========================================================================= - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Erik Bengtson -Andy Jefferson - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Joerg von Frantzius -Thomas Marti -Barry Haddow -Marco Schulze -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Marcus Mennemeier -Xuan Baldauf -Eric Sultan - -=================================================================== -This product also includes software developed by the TJDO project -(http://tjdo.sourceforge.net/). -=================================================================== - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Andy Jefferson -Erik Bengtson -Joerg von Frantzius -Marco Schulze - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Barry Haddow -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Anton Troshin (Timesten) - -=================================================================== -This product also includes software developed by the Apache Commons project -(http://commons.apache.org/). -=================================================================== - -Apache Java Data Objects (JDO) -Copyright 2005-2006 The Apache Software Foundation - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the Apache Derby distribution. == -========================================================================= - -Apache Derby -Copyright 2004-2008 The Apache Software Foundation - -Portions of Derby were originally developed by -International Business Machines Corporation and are -licensed to the Apache Software Foundation under the -"Software Grant and Corporate Contribution License Agreement", -informally known as the "Derby CLA". -The following copyright notice(s) were affixed to portions of the code -with which this file is now or was at one time distributed -and are placed here unaltered. - -(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. - -(C) Copyright IBM Corp. 2003. 
- -The portion of the functionTests under 'nist' was originally -developed by the National Institute of Standards and Technology (NIST), -an agency of the United States Department of Commerce, and adapted by -International Business Machines Corporation in accordance with the NIST -Software Acknowledgment and Redistribution document at -http://www.itl.nist.gov/div897/ctg/sql_form.htm - -Apache Commons Collections -Copyright 2001-2008 The Apache Software Foundation - -Apache Commons Configuration -Copyright 2001-2008 The Apache Software Foundation - -Apache Jakarta Commons Digester -Copyright 2001-2006 The Apache Software Foundation - -Apache Commons BeanUtils -Copyright 2000-2008 The Apache Software Foundation - -Apache Avro Mapred API -Copyright 2009-2013 The Apache Software Foundation - -Apache Avro IPC -Copyright 2009-2013 The Apache Software Foundation - - -Vis.js -Copyright 2010-2015 Almende B.V. - -Vis.js is dual licensed under both - - * The Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0 - - and - - * The MIT License - http://opensource.org/licenses/MIT - -Vis.js may be distributed under either license. - - -Vis.js uses and redistributes the following third-party libraries: - -- component-emitter - https://github.com/component/emitter - The MIT License - -- hammer.js - http://hammerjs.github.io/ - The MIT License - -- moment.js - http://momentjs.com/ - The MIT License - -- keycharm - https://github.com/AlexDM0/keycharm - The MIT License - -=============================================================================== - -The CSS style for the navigation sidebar of the documentation was originally -submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project -is distributed under the 3-Clause BSD license. -=============================================================================== - -For CSV functionality: - -/* - * Copyright 2014 Databricks - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2015 Ayasdi Inc - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -=============================================================================== -For dev/sparktestsupport/toposort.py: - -Copyright 2014 True Blade Systems, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+Export Control Notice
+---------------------
+
+This distribution includes cryptographic software. The country in which you currently reside may have
+restrictions on the import, possession, use, and/or re-export to another country, of encryption software.
+BEFORE using any encryption software, please check your country's laws, regulations and policies concerning
+the import, possession, or use, and re-export of encryption software, to see if this is permitted. See
+<http://www.wassenaar.org/> for more information.
+
+The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this
+software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software
+using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache
+Software Foundation distribution makes it eligible for export under the License Exception ENC Technology
+Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for
+both object code and source code.
+
+The following provides more details on the included cryptographic software:
+
+This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to
+support authentication, and encryption and decryption of data sent across the network between
+services.
diff --git a/NOTICE-binary b/NOTICE-binary
new file mode 100644
index 0000000000000..b707c436983f7
--- /dev/null
+++ b/NOTICE-binary
@@ -0,0 +1,1174 @@
+Apache Spark
+Copyright 2014 and onwards The Apache Software Foundation.
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Export Control Notice
+---------------------
+
+This distribution includes cryptographic software. The country in which you currently reside may have
+restrictions on the import, possession, use, and/or re-export to another country, of encryption software.
+BEFORE using any encryption software, please check your country's laws, regulations and policies concerning
+the import, possession, or use, and re-export of encryption software, to see if this is permitted. See
+<http://www.wassenaar.org/> for more information.
+
+The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this
+software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software
+using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache
+Software Foundation distribution makes it eligible for export under the License Exception ENC Technology
+Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for
+both object code and source code.
+
+The following provides more details on the included cryptographic software:
+
+This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to
+support authentication, and encryption and decryption of data sent across the network between
+services.
+ + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for +// ------------------------------------------------------------------ + +Hive Beeline +Copyright 2016 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro +Copyright 2009-2014 The Apache Software Foundation + +This product currently only contains code developed by authors +of specific components, as identified by the source code files; +if such notes are missing files have been created by +Tatu Saloranta. + +For additional credits (generally to people who reported problems) +see CREDITS file. + +Apache Commons Compress +Copyright 2002-2012 The Apache Software Foundation + +This product includes software developed by +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro Mapred API +Copyright 2009-2014 The Apache Software Foundation + +Apache Avro IPC +Copyright 2009-2014 The Apache Software Foundation + +Objenesis +Copyright 2006-2013 Joe Walnes, Henri Tremblay, Leonardo Mesquita + +Apache XBean :: ASM 5 shaded (repackaged) +Copyright 2005-2015 The Apache Software Foundation + +-------------------------------------- + +This product includes software developed at +OW2 Consortium (http://asm.ow2.org/) + +This product includes software developed by The Apache Software +Foundation (http://www.apache.org/). + +The binary distribution of this product bundles binaries of +org.iq80.leveldb:leveldb-api (https://github.com/dain/leveldb), which has the +following notices: +* Copyright 2011 Dain Sundstrom +* Copyright 2011 FuseSource Corp. http://fusesource.com + +The binary distribution of this product bundles binaries of +org.fusesource.hawtjni:hawtjni-runtime (https://github.com/fusesource/hawtjni), +which has the following notices: +* This product includes software developed by FuseSource Corp. + http://fusesource.com +* This product includes software developed at + Progress Software Corporation and/or its subsidiaries or affiliates. +* This product includes software developed by IBM Corporation and others. + +The binary distribution of this product bundles binaries of +Gson 2.2.4, +which has the following notices: + + The Netty Project + ================= + +Please visit the Netty web site for more information: + + * http://netty.io/ + +Copyright 2014 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE..txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T. 
Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified portion of 'Webbit', an event based +WebSocket and HTTP server, which can be obtained at: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product contains a modified portion of 'SLF4J', a simple logging +facade for Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * http://www.slf4j.org/ + +This product contains a modified portion of 'ArrayDeque', written by Josh +Bloch of Google, Inc: + + * LICENSE: + * license/LICENSE.deque.txt (Public Domain) + +This product contains a modified portion of 'Apache Harmony', an open source +Java SE, which can be obtained at: + + * LICENSE: + * license/LICENSE.harmony.txt (Apache License 2.0) + * HOMEPAGE: + * http://archive.apache.org/dist/harmony/ + +This product contains a modified version of Roland Kuhn's ASL2 +AbstractNodeQueue, which is based on Dmitriy Vyukov's non-intrusive MPSC queue. +It can be obtained at: + + * LICENSE: + * license/LICENSE.abstractnodequeue.txt (Public Domain) + * HOMEPAGE: + * https://github.com/akka/akka/blob/wip-2.2.3-for-scala-2.11/akka-actor/src/main/java/akka/dispatch/AbstractNodeQueue.java + +This product contains a modified portion of 'jbzip2', a Java bzip2 compression +and decompression library written by Matthew J. Francis. It can be obtained at: + + * LICENSE: + * license/LICENSE.jbzip2.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jbzip2/ + +This product contains a modified portion of 'libdivsufsort', a C API library to construct +the suffix array and the Burrows-Wheeler transformed string for any input string of +a constant-size alphabet written by Yuta Mori. It can be obtained at: + + * LICENSE: + * license/LICENSE.libdivsufsort.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/libdivsufsort/ + +This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, + which can be obtained at: + + * LICENSE: + * license/LICENSE.jctools.txt (ASL2 License) + * HOMEPAGE: + * https://github.com/JCTools/JCTools + +This product optionally depends on 'JZlib', a re-implementation of zlib in +pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product optionally depends on 'Compress-LZF', a Java library for encoding and +decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: + + * LICENSE: + * license/LICENSE.compress-lzf.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/ning/compress + +This product optionally depends on 'lz4', a LZ4 Java compression +and decompression library written by Adrien Grand. 
It can be obtained at: + + * LICENSE: + * license/LICENSE.lz4.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jpountz/lz4-java + +This product optionally depends on 'lzma-java', a LZMA Java compression +and decompression library, which can be obtained at: + + * LICENSE: + * license/LICENSE.lzma-java.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jponge/lzma-java + +This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +and decompression library written by William Kinney. It can be obtained at: + + * LICENSE: + * license/LICENSE.jfastlz.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jfastlz/ + +This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/protobuf/ + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * http://www.bouncycastle.org/ + +This product optionally depends on 'Snappy', a compression library produced +by Google Inc, which can be obtained at: + + * LICENSE: + * license/LICENSE.snappy.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/snappy/ + +This product optionally depends on 'JBoss Marshalling', an alternative Java +serialization API, which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-marshalling.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://www.jboss.org/jbossmarshalling + +This product optionally depends on 'Caliper', Google's micro- +benchmarking framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.caliper.txt (Apache License 2.0) + * HOMEPAGE: + * http://code.google.com/p/caliper/ + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * http://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, which +can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * http://logging.apache.org/log4j/ + +This product optionally depends on 'Aalto XML', an ultra-high performance +non-blocking XML processor, which can be obtained at: + + * LICENSE: + * license/LICENSE.aalto-xml.txt (Apache License 2.0) + * HOMEPAGE: + * http://wiki.fasterxml.com/AaltoHome + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Twitter. 
It can be obtained at: + + * LICENSE: + * license/LICENSE.hpack.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/twitter/hpack + +This product contains a modified portion of 'Apache Commons Lang', a Java library +provides utilities for the java.lang API, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-lang.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/proper/commons-lang/ + +The binary distribution of this product bundles binaries of +Commons Codec 1.4, +which has the following notices: + * src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.javacontains test data from http://aspell.net/test/orig/batch0.tab.Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + =============================================================================== + The content of package org.apache.commons.codec.language.bm has been translated + from the original php source code available at http://stevemorse.org/phoneticinfo.htm + with permission from the original authors. + Original source copyright:Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +The binary distribution of this product bundles binaries of +Commons Lang 2.6, +which has the following notices: + * This product includes software from the Spring Framework,under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +The binary distribution of this product bundles binaries of +Apache Log4j 1.2.17, +which has the following notices: + * ResolverUtil.java + Copyright 2005-2006 Tim Fennell + Dumbster SMTP test server + Copyright 2004 Jason Paul Kitchen + TypeUtil.java + Copyright 2002-2012 Ramnivas Laddad, Juergen Hoeller, Chris Beams + +The binary distribution of this product bundles binaries of +Jetty 6.1.26, +which has the following notices: + * ============================================================== + Jetty Web Container + Copyright 1995-2016 Mort Bay Consulting Pty Ltd. + ============================================================== + + The Jetty Web Container is Copyright Mort Bay Consulting Pty Ltd + unless otherwise noted. + + Jetty is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html + + and + + * The Eclipse Public 1.0 License + http://www.eclipse.org/legal/epl-v10.html + + Jetty may be distributed under either license. + + ------ + Eclipse + + The following artifacts are EPL. + * org.eclipse.jetty.orbit:org.eclipse.jdt.core + + The following artifacts are EPL and ASL2. + * org.eclipse.jetty.orbit:javax.security.auth.message + + The following artifacts are EPL and CDDL 1.0. + * org.eclipse.jetty.orbit:javax.mail.glassfish + + ------ + Oracle + + The following artifacts are CDDL + GPLv2 with classpath exception. + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + * javax.servlet:javax.servlet-api + * javax.annotation:javax.annotation-api + * javax.transaction:javax.transaction-api + * javax.websocket:javax.websocket-api + + ------ + Oracle OpenJDK + + If ALPN is used to negotiate HTTP/2 connections, then the following + artifacts may be included in the distribution or downloaded when ALPN + module is selected. + + * java.sun.security.ssl + + These artifacts replace/modify OpenJDK classes. The modififications + are hosted at github and both modified and original are under GPL v2 with + classpath exceptions. 
+ http://openjdk.java.net/legal/gplv2+ce.html + + ------ + OW2 + + The following artifacts are licensed by the OW2 Foundation according to the + terms of http://asm.ow2.org/license.html + + org.ow2.asm:asm-commons + org.ow2.asm:asm + + ------ + Apache + + The following artifacts are ASL2 licensed. + + org.apache.taglibs:taglibs-standard-spec + org.apache.taglibs:taglibs-standard-impl + + ------ + MortBay + + The following artifacts are ASL2 licensed. Based on selected classes from + following Apache Tomcat jars, all ASL2 licensed. + + org.mortbay.jasper:apache-jsp + org.apache.tomcat:tomcat-jasper + org.apache.tomcat:tomcat-juli + org.apache.tomcat:tomcat-jsp-api + org.apache.tomcat:tomcat-el-api + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-api + org.apache.tomcat:tomcat-util-scan + org.apache.tomcat:tomcat-util + + org.mortbay.jasper:apache-el + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-el-api + + ------ + Mortbay + + The following artifacts are CDDL + GPLv2 with classpath exception. + + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + org.eclipse.jetty.toolchain:jetty-schemas + + ------ + Assorted + + The UnixCrypt.java code implements the one way cryptography used by + Unix systems for simple password protection. Copyright 1996 Aki Yoshida, + modified April 2001 by Iris Van den Broeke, Daniel Deville. + Permission to use, copy, modify and distribute UnixCrypt + for non-commercial or commercial purposes and without fee is + granted provided that the copyright notice appears in all copies./ + +The binary distribution of this product bundles binaries of +Snappy for Java 1.0.4.1, +which has the following notices: + * This product includes software developed by Google + Snappy: http://code.google.com/p/snappy/ (New BSD License) + + This product includes software developed by Apache + PureJavaCrc32C from apache-hadoop-common http://hadoop.apache.org/ + (Apache 2.0 license) + + This library contains statically linked libstdc++. This inclusion is allowed by + "GCC RUntime Library Exception" + http://gcc.gnu.org/onlinedocs/libstdc++/manual/license.html + + == Contributors == + * Tatu Saloranta + * Providing benchmark suite + * Alec Wysoker + * Performance and memory usage improvement + +The binary distribution of this product bundles binaries of +Xerces2 Java Parser 2.9.1, +which has the following notices: + * ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Apache Xerces Java + Copyright 1999-2007 The Apache Software Foundation + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. 
+ +Apache Commons Collections +Copyright 2001-2015 The Apache Software Foundation + +Apache Commons Configuration +Copyright 2001-2008 The Apache Software Foundation + +Apache Jakarta Commons Digester +Copyright 2001-2006 The Apache Software Foundation + +Apache Commons BeanUtils +Copyright 2000-2008 The Apache Software Foundation + +ApacheDS Protocol Kerberos Codec +Copyright 2003-2013 The Apache Software Foundation + +ApacheDS I18n +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory API ASN.1 API +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory LDAP API Utilities +Copyright 2003-2013 The Apache Software Foundation + +Curator Client +Copyright 2011-2015 The Apache Software Foundation + +htrace-core +Copyright 2015 The Apache Software Foundation + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system project uses. + +Apache HttpCore +Copyright 2005-2017 The Apache Software Foundation + +Curator Recipes +Copyright 2011-2015 The Apache Software Foundation + +Curator Framework +Copyright 2011-2015 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2016 The Apache Software Foundation + +This product includes software from the Spring Framework, +under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +Apache Commons Math +Copyright 2001-2015 The Apache Software Foundation + +This product includes software developed for Orekit by +CS Systèmes d'Information (http://www.c-s.fr/) +Copyright 2010-2012 CS Systèmes d'Information + +Apache log4j +Copyright 2007 The Apache Software Foundation + +# Compress LZF + +This library contains efficient implementation of LZF compression format, +as well as additional helper classes that build on JDK-provided gzip (deflat) +codec. + +Library is licensed under Apache License 2.0, as per accompanying LICENSE file. + +## Credit + +Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). 
+It was started at Ning, inc., as an official Open Source process used by +platform backend, but after initial versions has been developed outside of +Ning by supporting community. + +Other contributors include: + +* Jon Hartlaub (first versions of streaming reader/writer; unit tests) +* Cedrik Lime: parallel LZF implementation + +Various community members have contributed bug reports, and suggested minor +fixes; these can be found from file "VERSION.txt" in SCM. + +Apache Commons Net +Copyright 2001-2012 The Apache Software Foundation + +Copyright 2011 The Netty Project + +http://www.apache.org/licenses/LICENSE-2.0 + +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD Style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: + +This product optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: + +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: + +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ + +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: + + * LICENSE: + * license/LICENSE.felix.txt (Apache License 2.0) + * HOMEPAGE: + * http://felix.apache.org/ + +Jackson core and extension components may be licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +Apache Ivy (TM) +Copyright 2007-2014 The Apache Software Foundation + +Portions of Ivy were originally developed at +Jayasoft SARL (http://www.jayasoft.fr/) +and are licensed to the Apache Software Foundation under the +"Software Grant License Agreement" + +SSH and SFTP support is provided by the JCraft JSch package, +which is open source software, available under +the terms of a BSD style license. +The original software and related information is available +at http://www.jcraft.com/jsch/. + + +ORC Core +Copyright 2013-2018 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2011 The Apache Software Foundation + +ORC MapReduce +Copyright 2013-2018 The Apache Software Foundation + +Apache Parquet Format +Copyright 2017 The Apache Software Foundation + +Arrow Vectors +Copyright 2017 The Apache Software Foundation + +Arrow Format +Copyright 2017 The Apache Software Foundation + +Arrow Memory +Copyright 2017 The Apache Software Foundation + +Apache Commons CLI +Copyright 2001-2009 The Apache Software Foundation + +Google Guice - Extensions - Servlet +Copyright 2006-2011 Google, Inc. + +Apache Commons IO +Copyright 2002-2012 The Apache Software Foundation + +Google Guice - Core Library +Copyright 2006-2011 Google, Inc. 
+ +mesos +Copyright 2017 The Apache Software Foundation + +Apache Parquet Hadoop Bundle (Incubating) +Copyright 2015 The Apache Software Foundation + +Hive Query Language +Copyright 2016 The Apache Software Foundation + +Apache Extras Companion for log4j 1.2. +Copyright 2007 The Apache Software Foundation + +Hive Metastore +Copyright 2016 The Apache Software Foundation + +Apache Commons Logging +Copyright 2003-2013 The Apache Software Foundation + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, == +== Version 2.0, in this case for the DataNucleus distribution. == +========================================================================= + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Erik Bengtson +Andy Jefferson + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Andy Jefferson +Erik Bengtson +Joerg von Frantzius +Marco Schulze + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Barry Haddow +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Anton Troshin (Timesten) + +=================================================================== +This product also includes software developed by the TJDO project +(http://tjdo.sourceforge.net/). +=================================================================== + +=================================================================== +This product also includes software developed by the Apache Commons project +(http://commons.apache.org/). +=================================================================== + +Apache Commons Pool +Copyright 1999-2009 The Apache Software Foundation + +Apache Commons DBCP +Copyright 2001-2010 The Apache Software Foundation + +Apache Java Data Objects (JDO) +Copyright 2005-2006 The Apache Software Foundation + +Apache Jakarta HttpClient +Copyright 1999-2007 The Apache Software Foundation + +Calcite Avatica +Copyright 2012-2015 The Apache Software Foundation + +Calcite Core +Copyright 2012-2015 The Apache Software Foundation + +Calcite Linq4j +Copyright 2012-2015 The Apache Software Foundation + +Apache HttpClient +Copyright 1999-2017 The Apache Software Foundation + +Apache Commons Codec +Copyright 2002-2014 The Apache Software Foundation + +src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java +contains test data from http://aspell.net/test/orig/batch0.tab. +Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + +=============================================================================== + +The content of package org.apache.commons.codec.language.bm has been translated +from the original php source code available at http://stevemorse.org/phoneticinfo.htm +with permission from the original authors. 
+Original source copyright: +Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Joerg von Frantzius +Thomas Marti +Barry Haddow +Marco Schulze +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Marcus Mennemeier +Xuan Baldauf +Eric Sultan + +Apache Thrift +Copyright 2006-2010 The Apache Software Foundation. + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, +== Version 2.0, in this case for the Apache Derby distribution. +== +== DO NOT EDIT THIS FILE DIRECTLY. IT IS GENERATED +== BY THE buildnotice TARGET IN THE TOP LEVEL build.xml FILE. +== +========================================================================= + +Apache Derby +Copyright 2004-2015 The Apache Software Foundation + +========================================================================= + +Portions of Derby were originally developed by +International Business Machines Corporation and are +licensed to the Apache Software Foundation under the +"Software Grant and Corporate Contribution License Agreement", +informally known as the "Derby CLA". +The following copyright notice(s) were affixed to portions of the code +with which this file is now or was at one time distributed +and are placed here unaltered. + +(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. + +(C) Copyright IBM Corp. 2003. + +The portion of the functionTests under 'nist' was originally +developed by the National Institute of Standards and Technology (NIST), +an agency of the United States Department of Commerce, and adapted by +International Business Machines Corporation in accordance with the NIST +Software Acknowledgment and Redistribution document at +http://www.itl.nist.gov/div897/ctg/sql_form.htm + +The JDBC apis for small devices and JDBC3 (under java/stubs/jsr169 and +java/stubs/jdbc3) were produced by trimming sources supplied by the +Apache Harmony project. In addition, the Harmony SerialBlob and +SerialClob implementations are used. The following notice covers the Harmony sources: + +Portions of Harmony were originally developed by +Intel Corporation and are licensed to the Apache Software +Foundation under the "Software Grant and Corporate Contribution +License Agreement", informally known as the "Intel Harmony CLA". + +The Derby build relies on source files supplied by the Apache Felix +project. The following notice covers the Felix files: + + Apache Felix Main + Copyright 2008 The Apache Software Foundation + + I. Included Software + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Licensed under the Apache License 2.0. + + This product includes software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + This product includes software from http://kxml.sourceforge.net. 
+ Copyright (c) 2002,2003, Stefan Haustein, Oberhausen, Rhld., Germany. + Licensed under BSD License. + + II. Used Software + + This product uses software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + III. License Summary + - Apache License 2.0 + - BSD License + +The Derby build relies on jar files supplied by the Apache Lucene +project. The following notice covers the Lucene files: + +Apache Lucene +Copyright 2013 The Apache Software Foundation + +Includes software from other Apache Software Foundation projects, +including, but not limited to: + - Apache Ant + - Apache Jakarta Regexp + - Apache Commons + - Apache Xerces + +ICU4J, (under analysis/icu) is licensed under an MIT styles license +and Copyright (c) 1995-2008 International Business Machines Corporation and others + +Some data files (under analysis/icu/src/data) are derived from Unicode data such +as the Unicode Character Database. See http://unicode.org/copyright.html for more +details. + +Brics Automaton (under core/src/java/org/apache/lucene/util/automaton) is +BSD-licensed, created by Anders Møller. See http://www.brics.dk/automaton/ + +The levenshtein automata tables (under core/src/java/org/apache/lucene/util/automaton) were +automatically generated with the moman/finenight FSA library, created by +Jean-Philippe Barrette-LaPierre. This library is available under an MIT license, +see http://sites.google.com/site/rrettesite/moman and +http://bitbucket.org/jpbarrette/moman/overview/ + +The class org.apache.lucene.util.WeakIdentityMap was derived from +the Apache CXF project and is Apache License 2.0. + +The Google Code Prettify is Apache License 2.0. +See http://code.google.com/p/google-code-prettify/ + +JUnit (junit-4.10) is licensed under the Common Public License v. 1.0 +See http://junit.sourceforge.net/cpl-v10.html + +This product includes code (JaspellTernarySearchTrie) from Java Spelling Checkin +g Package (jaspell): http://jaspell.sourceforge.net/ +License: The BSD License (http://www.opensource.org/licenses/bsd-license.php) + +The snowball stemmers in + analysis/common/src/java/net/sf/snowball +were developed by Martin Porter and Richard Boulton. +The snowball stopword lists in + analysis/common/src/resources/org/apache/lucene/analysis/snowball +were developed by Martin Porter and Richard Boulton. +The full snowball package is available from + http://snowball.tartarus.org/ + +The KStem stemmer in + analysis/common/src/org/apache/lucene/analysis/en +was developed by Bob Krovetz and Sergio Guzman-Lara (CIIR-UMass Amherst) +under the BSD-license. + +The Arabic,Persian,Romanian,Bulgarian, and Hindi analyzers (common) come with a default +stopword list that is BSD-licensed created by Jacques Savoy. These files reside in: +analysis/common/src/resources/org/apache/lucene/analysis/ar/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/fa/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/ro/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/bg/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/hi/stopwords.txt +See http://members.unine.ch/jacques.savoy/clef/index.html. + +The German,Spanish,Finnish,French,Hungarian,Italian,Portuguese,Russian and Swedish light stemmers +(common) are based on BSD-licensed reference implementations created by Jacques Savoy and +Ljiljana Dolamic. 
These files reside in: +analysis/common/src/java/org/apache/lucene/analysis/de/GermanLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/de/GermanMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/es/SpanishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fi/FinnishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/hu/HungarianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/it/ItalianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/pt/PortugueseLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/ru/RussianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/sv/SwedishLightStemmer.java + +The Stempel analyzer (stempel) includes BSD-licensed software developed +by the Egothor project http://egothor.sf.net/, created by Leo Galambos, Martin Kvapil, +and Edmond Nolan. + +The Polish analyzer (stempel) comes with a default +stopword list that is BSD-licensed created by the Carrot2 project. The file resides +in stempel/src/resources/org/apache/lucene/analysis/pl/stopwords.txt. +See http://project.carrot2.org/license.html. + +The SmartChineseAnalyzer source code (smartcn) was +provided by Xiaoping Gao and copyright 2009 by www.imdict.net. + +WordBreakTestUnicode_*.java (under modules/analysis/common/src/test/) +is derived from Unicode data such as the Unicode Character Database. +See http://unicode.org/copyright.html for more details. + +The Morfologik analyzer (morfologik) includes BSD-licensed software +developed by Dawid Weiss and Marcin Miłkowski (http://morfologik.blogspot.com/). + +Morfologik uses data from Polish ispell/myspell dictionary +(http://www.sjp.pl/slownik/en/) licenced on the terms of (inter alia) +LGPL and Creative Commons ShareAlike. + +Morfologic includes data from BSD-licensed dictionary of Polish (SGJP) +(http://sgjp.pl/morfeusz/) + +Servlet-api.jar and javax.servlet-*.jar are under the CDDL license, the original +source code for this can be found at http://www.eclipse.org/jetty/downloads.php + +=========================================================================== +Kuromoji Japanese Morphological Analyzer - Apache Lucene Integration +=========================================================================== + +This software includes a binary and/or source version of data from + + mecab-ipadic-2.7.0-20070801 + +which can be obtained from + + http://atilika.com/releases/mecab-ipadic/mecab-ipadic-2.7.0-20070801.tar.gz + +or + + http://jaist.dl.sourceforge.net/project/mecab/mecab-ipadic/2.7.0-20070801/mecab-ipadic-2.7.0-20070801.tar.gz + +=========================================================================== +mecab-ipadic-2.7.0-20070801 Notice +=========================================================================== + +Nara Institute of Science and Technology (NAIST), +the copyright holders, disclaims all warranties with regard to this +software, including all implied warranties of merchantability and +fitness, in no event shall NAIST be liable for +any special, indirect or consequential damages or any damages +whatsoever resulting from loss of use, data or profits, whether in an +action of contract, negligence or other tortuous action, arising out +of or in connection with the use or performance of this software. 
+ +A large portion of the dictionary entries +originate from ICOT Free Software. The following conditions for ICOT +Free Software applies to the current dictionary as well. + +Each User may also freely distribute the Program, whether in its +original form or modified, to any third party or parties, PROVIDED +that the provisions of Section 3 ("NO WARRANTY") will ALWAYS appear +on, or be attached to, the Program, which is distributed substantially +in the same form as set out herein and that such intended +distribution, if actually made, will neither violate or otherwise +contravene any of the laws and regulations of the countries having +jurisdiction over the User or the intended distribution itself. + +NO WARRANTY + +The program was produced on an experimental basis in the course of the +research and development conducted during the project and is provided +to users as so produced on an experimental basis. Accordingly, the +program is provided without any warranty whatsoever, whether express, +implied, statutory or otherwise. The term "warranty" used herein +includes, but is not limited to, any warranty of the quality, +performance, merchantability and fitness for a particular purpose of +the program and the nonexistence of any infringement or violation of +any right of any third party. + +Each user of the program will agree and understand, and be deemed to +have agreed and understood, that there is no warranty whatsoever for +the program and, accordingly, the entire risk arising from or +otherwise connected with the program is assumed by the user. + +Therefore, neither ICOT, the copyright holder, or any other +organization that participated in or was otherwise related to the +development of the program and their respective officials, directors, +officers and other employees shall be held liable for any and all +damages, including, without limitation, general, special, incidental +and consequential damages, arising out of or otherwise in connection +with the use or inability to use the program or any product, material +or result produced or otherwise obtained by using the program, +regardless of whether they have been advised of, or otherwise had +knowledge of, the possibility of such damages at any time during the +project or thereafter. Each user will be deemed to have agreed to the +foregoing by his or her commencement of use of the program. The term +"use" as used herein includes, but is not limited to, the use, +modification, copying and distribution of the program and the +production of secondary products from the program. + +In the case where the program, whether in its original form or +modified, was distributed or delivered to or received by a user from +any person, organization or entity other than ICOT, unless it makes or +grants independently of ICOT any specific warranty to the user in +writing, such person, organization or entity, will also be exempted +from and not be held liable to the user for any such damages as noted +above as far as the program is concerned. + +The Derby build relies on a jar file supplied by the JSON Simple +project, hosted at https://code.google.com/p/json-simple/. +The JSON simple jar file is licensed under the Apache 2.0 License. + +Hive CLI +Copyright 2016 The Apache Software Foundation + +Hive JDBC +Copyright 2016 The Apache Software Foundation + + +Chill is a set of Scala extensions for Kryo. +Copyright 2012 Twitter, Inc. 
+ +Third Party Dependencies: + +Kryo 2.17 +BSD 3-Clause License +http://code.google.com/p/kryo + +Commons-Codec 1.7 +Apache Public License 2.0 +http://hadoop.apache.org + + + +Breeze is distributed under an Apache License V2.0 (See LICENSE) + +=============================================================================== + +Proximal algorithms outlined in Proximal.scala (package breeze.optimize.proximal) +are based on https://github.com/cvxgrp/proximal (see LICENSE for details) and distributed with +Copyright (c) 2014 by Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +QuadraticMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2014, Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +NonlinearMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2015, Debasish Das (Verizon), all rights reserved. + + +stream-lib +Copyright 2016 AddThis + +This product includes software developed by AddThis. + +This product also includes code adapted from: + +Apache Solr (http://lucene.apache.org/solr/) +Copyright 2014 The Apache Software Foundation + +Apache Mahout (http://mahout.apache.org/) +Copyright 2014 The Apache Software Foundation diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f16..f52d785e05cdd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -13,6 +13,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html +SystemRequirements: Java (== 8) Depends: R (>= 3.0), methods diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 190c50ea10482..0fd08482c4413 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -117,6 +117,7 @@ exportMethods("arrange", "dropna", "dtypes", "except", + "exceptAll", "explain", "fillna", "filter", @@ -131,6 +132,7 @@ exportMethods("arrange", "hint", "insertInto", "intersect", + "intersectAll", "isLocal", "isStreaming", "join", @@ -201,6 +203,16 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_distinct", + "array_join", + "array_max", + "array_min", + "array_position", + "array_remove", + "array_repeat", + "array_sort", + "arrays_overlap", + "arrays_zip", "asc", "ascii", "asin", @@ -245,6 +257,7 @@ exportMethods("%<=>%", "decode", "dense_rank", "desc", + "element_at", "encode", "endsWith", "exp", @@ -254,6 +267,7 @@ exportMethods("%<=>%", "expr", "factorial", "first", + "flatten", "floor", "format_number", "format_string", @@ -296,6 +310,8 @@ exportMethods("%<=>%", "lower", "lpad", "ltrim", + "map_entries", + "map_from_arrays", "map_keys", "map_values", "max", @@ -346,6 +362,7 @@ exportMethods("%<=>%", "sinh", "size", "skewness", + "slice", "sort_array", "soundex", "spark_partition_id", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1c9495b0795e..4f2d4c7c002d4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -588,7 +588,7 @@ setMethod("cache", #' \url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}. #' #' @param x the SparkDataFrame to persist. -#' @param newLevel storage level chosen for the persistance. See available options in +#' @param newLevel storage level chosen for the persistence. See available options in #' the description. 
#' #' @family SparkDataFrame functions @@ -2297,6 +2297,8 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) +setClassUnion("numericOrColumn", c("numeric", "Column")) + #' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). @@ -2846,6 +2848,35 @@ setMethod("intersect", dataFrame(intersected) }) +#' intersectAll +#' +#' Return a new SparkDataFrame containing rows in both this SparkDataFrame +#' and another SparkDataFrame while preserving the duplicates. +#' This is equivalent to \code{INTERSECT ALL} in SQL. Also as standard in +#' SQL, this function resolves columns by position (not by name). +#' +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. +#' @return A SparkDataFrame containing the result of the intersect all operation. +#' @family SparkDataFrame functions +#' @aliases intersectAll,SparkDataFrame,SparkDataFrame-method +#' @rdname intersectAll +#' @name intersectAll +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' intersectAllDF <- intersectAll(df1, df2) +#' } +#' @note intersectAll since 2.4.0 +setMethod("intersectAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersectAll", y@sdf) + dataFrame(intersected) + }) + #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame @@ -2865,7 +2896,6 @@ setMethod("intersect", #' df2 <- read.json(path2) #' exceptDF <- except(df, df2) #' } -#' @rdname except #' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2874,6 +2904,35 @@ setMethod("except", dataFrame(excepted) }) +#' exceptAll +#' +#' Return a new SparkDataFrame containing rows in this SparkDataFrame +#' but not in another SparkDataFrame while preserving the duplicates. +#' This is equivalent to \code{EXCEPT ALL} in SQL. Also as standard in +#' SQL, this function resolves columns by position (not by name). +#' +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. +#' @return A SparkDataFrame containing the result of the except all operation. +#' @family SparkDataFrame functions +#' @aliases exceptAll,SparkDataFrame,SparkDataFrame-method +#' @rdname exceptAll +#' @name exceptAll +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' exceptAllDF <- exceptAll(df1, df2) +#' } +#' @note exceptAll since 2.4.0 +setMethod("exceptAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + excepted <- callJMethod(x@sdf, "exceptAll", y@sdf) + dataFrame(excepted) + }) + #' Save the contents of SparkDataFrame to a data source. #' #' The data source is specified by the \code{source} and a set of options (...). diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 429dd5d565492..c819a7d14ae98 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -351,7 +351,7 @@ setMethod("toDF", signature(x = "RDD"), read.json.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -421,7 +421,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { read.orc <- function(path, ...) 
{ sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the ORC file path + # Allow the user to have a more flexible definition of the ORC file path path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -442,7 +442,7 @@ read.orc <- function(path, ...) { read.parquet.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the Parquet file path + # Allow the user to have a more flexible definition of the Parquet file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -492,7 +492,7 @@ parquetFile <- function(x, ...) { read.text.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 9d82814211bc5..660f0864403e0 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout) { +connectBackend <- function(hostname, port, timeout, authSecret) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") @@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) { con <- socketConnection(host = hostname, port = port, server = FALSE, blocking = TRUE, open = "wb", timeout = timeout) - + doServerAuth(con, authSecret) assign(".sparkRCon", con, envir = .sparkREnv) con } @@ -60,6 +60,47 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack combinedArgs } +checkJavaVersion <- function() { + javaBin <- "java" + javaHome <- Sys.getenv("JAVA_HOME") + javaReqs <- utils::packageDescription(utils::packageName(), fields = c("SystemRequirements")) + sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) + if (javaHome != "") { + javaBin <- file.path(javaHome, "bin", javaBin) + } + + # If java is missing from PATH, we get an error in Unix and a warning in Windows + javaVersionOut <- tryCatch( + if (is_windows()) { + # See SPARK-24535 + system2(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + } else { + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + }, + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) + javaVersionFilter <- Filter( + function(x) { + grepl(" version", x) + }, javaVersionOut) + + javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] + # javaVersionStr is of the form 1.8.0_92. 
+ # Extract 8 from it to compare to sparkJavaVersion + javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) + if (javaVersionNum != sparkJavaVersion) { + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", + javaVersionStr)) + } + return(javaVersionNum) +} + launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { @@ -67,6 +108,7 @@ launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { } else { sparkSubmitBin <- sparkSubmitBinName } + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(launchScript(sparkSubmitBin, combinedArgs)) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8ec727dd042bc..f168ca76b6007 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -43,7 +43,7 @@ getMinPartitions <- function(sc, minPartitions) { #' lines <- textFile(sc, "myfile.txt") #'} textFile <- function(sc, path, minPartitions = NULL) { - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") @@ -71,7 +71,7 @@ textFile <- function(sc, path, minPartitions = NULL) { #' rdd <- objectFile(sc, "myfile") #'} objectFile <- function(sc, path, minPartitions = NULL) { - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") @@ -138,11 +138,10 @@ parallelize <- function(sc, coll, numSlices = 1) { sizeLimit <- getMaxAllocationLimit(sc) objectSize <- object.size(coll) + len <- length(coll) # For large objects we make sure the size of each slice is also smaller than sizeLimit - numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) - if (numSerializedSlices > length(coll)) - numSerializedSlices <- length(coll) + numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit))) # Generate the slice ids to put each row # For instance, for numSerializedSlices of 22, length of 50 @@ -153,8 +152,8 @@ parallelize <- function(sc, coll, numSlices = 1) { splits <- if (numSerializedSlices > 0) { unlist(lapply(0: (numSerializedSlices - 1), function(x) { # nolint start - start <- trunc((x * length(coll)) / numSerializedSlices) - end <- trunc(((x + 1) * length(coll)) / numSerializedSlices) + start <- trunc((as.numeric(x) * len) / numSerializedSlices) + end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices) # nolint end rep(start, end - start) })) @@ -305,6 +304,8 @@ setCheckpointDirSC <- function(sc, dirName) { #' Currently directories are only supported for Hadoop-supported filesystems. #' Refer Hadoop-supported filesystems at \url{https://wiki.apache.org/hadoop/HCFS}. #' +#' Note: A path can be added only once. Subsequent additions of the same path are ignored. +#' #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. 
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index a90f7d381026b..cb03f1667629f 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -60,14 +60,18 @@ readTypedObject <- function(con, type) { stop(paste("Unsupported type for deserialization", type))) } -readString <- function(con) { - stringLen <- readInt(con) - raw <- readBin(con, raw(), stringLen, endian = "big") +readStringData <- function(con, len) { + raw <- readBin(con, raw(), len, endian = "big") string <- rawToChar(raw) Encoding(string) <- "UTF-8" string } +readString <- function(con) { + stringLen <- readInt(con) + readStringData(con, stringLen) +} + readInt <- function(con) { readBin(con, integer(), n = 1, endian = "big") } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a527426b19674..2929a00330c62 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,9 +189,17 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param y Column to compute on. +#' @param value A value to compute on. +#' \itemize{ +#' \item \code{array_contains}: a value to be checked if contained in the column. +#' \item \code{array_position}: a value to locate in the given array. +#' \item \code{array_remove}: a value to remove in the given array. +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. +#' options as the JSON data source. In \code{arrays_zip}, this contains additional +#' Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -201,14 +209,24 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, slice(tmp$v1, 2L, 2L))) #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) -#' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3)))} +#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3))) +#' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5))) +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) +#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) +#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} NULL #' Window functions for Column operations @@ -796,6 +814,8 @@ setMethod("factorial", #' #' The function by default returns the first values it sees. It will return the first non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. 
+#' Note: the function is non-deterministic because its result depends on the order of rows, which +#' may be non-deterministic after a shuffle. #' #' @param na.rm a logical value indicating whether NA values should be stripped #' before the computation proceeds. @@ -939,6 +959,8 @@ setMethod("kurtosis", #' #' The function by default returns the last values it sees. It will return the last non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its result depends on the order of rows, which +#' may be non-deterministic after a shuffle. #' #' @param x column to compute on. #' @param na.rm a logical value indicating whether NA values should be stripped @@ -1192,6 +1214,7 @@ setMethod("minute", #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. #' The method should be used with no argument. +#' Note: the function is non-deterministic because its result depends on partition IDs. #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method @@ -1245,9 +1268,9 @@ setMethod("quarter", }) #' @details -#' \code{reverse}: Reverses the string column and returns it as a new string column. +#' \code{reverse}: Returns a reversed string or an array with reverse order of elements. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases reverse reverse,Column-method #' @note reverse since 1.5.0 setMethod("reverse", @@ -1898,6 +1921,7 @@ setMethod("atan2", signature(y = "Column"), #' @details #' \code{datediff}: Returns the number of days from \code{y} to \code{x}. +#' If \code{y} is later than \code{x}, then the result is positive. #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method @@ -1958,6 +1982,9 @@ setMethod("levenshtein", signature(y = "Column"), #' @details #' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} +#' are on the same day of month, or both are the last day of month, time of day will be ignored. +#' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method @@ -2036,20 +2063,10 @@ setMethod("countDistinct", #' @details #' \code{concat}: Concatenates multiple input columns together into a single column. -#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. +#' The function works with strings, binary and compatible array columns. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases concat concat,Column-method -#' @examples -#' -#' \dontrun{ -#' # concatenate strings -#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), -#' s2 = concat(df$Class, df$Sex, df$Age), -#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), -#' s4 = concat_ws("_", df$Class, df$Sex), -#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) -#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2390,6 +2407,13 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use.
#' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat_ws("_", df$Class, df$Sex), +#' s2 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2575,6 +2599,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @details #' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) #' samples from U[0.0, 1.0]. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2603,6 +2628,7 @@ setMethod("rand", signature(seed = "numeric"), #' @details #' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples #' from the standard normal distribution. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -2975,7 +3001,6 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @note array_contains since 1.6.0 @@ -2986,6 +3011,204 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_distinct}: Removes duplicate values from the array. +#' +#' @rdname column_collection_functions +#' @aliases array_distinct array_distinct,Column-method +#' @note array_distinct since 2.4.0 +setMethod("array_distinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_distinct", x@jc) + column(jc) + }) + +#' @details +#' \code{array_join}: Concatenates the elements of column using the delimiter. +#' Null values are replaced with nullReplacement if set, otherwise they are ignored. +#' +#' @param delimiter a character string that is used to concatenate the elements of column. +#' @param nullReplacement an optional character string that is used to replace the Null values. +#' @rdname column_collection_functions +#' @aliases array_join array_join,Column-method +#' @note array_join since 2.4.0 +setMethod("array_join", + signature(x = "Column", delimiter = "character"), + function(x, delimiter, nullReplacement = NULL) { + jc <- if (is.null(nullReplacement)) { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter) + } else { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter, + as.character(nullReplacement)) + } + column(jc) + }) + +#' @details +#' \code{array_max}: Returns the maximum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_max array_max,Column-method +#' @note array_max since 2.4.0 +setMethod("array_max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_max", x@jc) + column(jc) + }) + +#' @details +#' \code{array_min}: Returns the minimum value of the array. 
+#' +#' @rdname column_collection_functions +#' @aliases array_min array_min,Column-method +#' @note array_min since 2.4.0 +setMethod("array_min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_min", x@jc) + column(jc) + }) + +#' @details +#' \code{array_position}: Locates the position of the first occurrence of the given value +#' in the given array. Returns NA if either of the arguments are NA. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the given +#' value could not be found in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_position array_position,Column-method +#' @note array_position since 2.4.0 +setMethod("array_position", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value) + column(jc) + }) + +#' @details +#' \code{array_remove}: Removes all elements that equal to element from the given array. +#' +#' @rdname column_collection_functions +#' @aliases array_remove array_remove,Column-method +#' @note array_remove since 2.4.0 +setMethod("array_remove", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_remove", x@jc, value) + column(jc) + }) + +#' @details +#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times +#' given by \code{count}. +#' +#' @param count a Column or constant determining the number of repetitions. +#' @rdname column_collection_functions +#' @aliases array_repeat array_repeat,Column,numericOrColumn-method +#' @note array_repeat since 2.4.0 +setMethod("array_repeat", + signature(x = "Column", count = "numericOrColumn"), + function(x, count) { + if (class(count) == "Column") { + count <- count@jc + } else { + count <- as.integer(count) + } + jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count) + column(jc) + }) + +#' @details +#' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array +#' must be orderable. NA elements will be placed at the end of the returned array. +#' +#' @rdname column_collection_functions +#' @aliases array_sort array_sort,Column-method +#' @note array_sort since 2.4.0 +setMethod("array_sort", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc) + column(jc) + }) + +#' @details +#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in +#' common. If not and both arrays are non-empty and any of them contains a null, it returns null. +#' It returns false otherwise. +#' +#' @rdname column_collection_functions +#' @aliases arrays_overlap arrays_overlap,Column-method +#' @note arrays_overlap since 2.4.0 +setMethod("arrays_overlap", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc) + column(jc) + }) + +#' @details +#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th +#' values of input arrays. +#' +#' @rdname column_collection_functions +#' @aliases arrays_zip arrays_zip,Column-method +#' @note arrays_zip since 2.4.0 +setMethod("arrays_zip", + signature(x = "Column"), + function(x, ...) 
{ + jcols <- lapply(list(x, ...), function(arg) { + stopifnot(class(arg) == "Column") + arg@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_zip", jcols) + column(jc) + }) + +#' @details +#' \code{flatten}: Creates a single array from an array of arrays. +#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. +#' +#' @rdname column_collection_functions +#' @aliases flatten flatten,Column-method +#' @note flatten since 2.4.0 +setMethod("flatten", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "flatten", x@jc) + column(jc) + }) + +#' @details +#' \code{map_entries}: Returns an unordered array of all entries in the given map. +#' +#' @rdname column_collection_functions +#' @aliases map_entries map_entries,Column-method +#' @note map_entries since 2.4.0 +setMethod("map_entries", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc) + column(jc) + }) + +#' @details +#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for +#' keys. The array in the second column is used for values. All elements in the array for key +#' should not be null. +#' +#' @rdname column_collection_functions +#' @aliases map_from_arrays map_from_arrays,Column-method +#' @note map_from_arrays since 2.4.0 +setMethod("map_from_arrays", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_from_arrays", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' @@ -3012,6 +3235,22 @@ setMethod("map_values", column(jc) }) +#' @details +#' \code{element_at}: Returns element of array at given index in \code{extraction} if +#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. +#' Note: The position is not zero based, but 1 based index. +#' +#' @param extraction index to check for in array or key to check for in map +#' @rdname column_collection_functions +#' @aliases element_at element_at,Column-method +#' @note element_at since 2.4.0 +setMethod("element_at", + signature(x = "Column", extraction = "ANY"), + function(x, extraction) { + jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction) + column(jc) + }) + #' @details #' \code{explode}: Creates a new row for each element in the given array or map column. #' @@ -3039,8 +3278,25 @@ setMethod("size", }) #' @details -#' \code{sort_array}: Sorts the input array in ascending or descending order according -#' to the natural ordering of the array elements. +#' \code{slice}: Returns an array containing all the elements in x from the index start +#' (or starting from the end if start is negative) with the specified length. +#' +#' @rdname column_collection_functions +#' @param start an index indicating the first element occurring in the result. +#' @param length a number of consecutive elements chosen to the result. +#' @aliases slice slice,Column-method +#' @note slice since 2.4.0 +setMethod("slice", + signature(x = "Column"), + function(x, start, length) { + jc <- callJStatic("org.apache.spark.sql.functions", "slice", x@jc, start, length) + column(jc) + }) + +#' @details +#' \code{sort_array}: Sorts the input array in ascending or descending order according to +#' the natural ordering of the array elements. 
NA elements will be placed at the beginning of +#' the returned array in ascending order or at the end of the returned array in descending order. #' #' @rdname column_collection_functions #' @param asc a logical flag indicating the sorting order. @@ -3109,6 +3365,8 @@ setMethod("create_map", #' @details #' \code{collect_list}: Creates a list of objects with duplicates. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method @@ -3128,6 +3386,8 @@ setMethod("collect_list", #' @details #' \code{collect_set}: Creates a list of objects with duplicate elements eliminated. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 974beff1a3d76..f6f1849787a23 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -471,6 +471,9 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname exceptAll +setGeneric("exceptAll", function(x, y) { standardGeneric("exceptAll") }) + #' @rdname nafunctions setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) @@ -495,6 +498,9 @@ setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertIn #' @rdname intersect setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) +#' @rdname intersectAll +setGeneric("intersectAll", function(x, y) { standardGeneric("intersectAll") }) + #' @rdname isLocal setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) @@ -624,7 +630,7 @@ setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary setGeneric("summary", function(object, ...) { standardGeneric("summary") }) -setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) +setGeneric("toJSON", function(x, ...) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) @@ -757,6 +763,46 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_join", function(x, delimiter, ...) 
{ standardGeneric("array_join") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_max", function(x) { standardGeneric("array_max") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_min", function(x) { standardGeneric("array_min") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_remove", function(x, value) { standardGeneric("array_remove") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -801,7 +847,7 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) @@ -886,6 +932,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") }) + #' @rdname column_string_functions #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) @@ -902,6 +952,10 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("flatten", function(x) { standardGeneric("flatten") }) + #' @rdname column_datetime_diff_functions #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) @@ -1010,6 +1064,14 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) @@ -1110,7 +1172,7 @@ setGeneric("regexp_replace", #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) @@ -1170,6 +1232,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname 
column_collection_functions +#' @name NULL +setGeneric("slice", function(x, start, length) { standardGeneric("slice") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 6769be038efa9..0e60842dd44c8 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -362,7 +362,18 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' For regression, must be "variance". For classification, must be one of #' "entropy" and "gini", default is "gini". #' @param featureSubsetStrategy The number of features to consider for splits at each tree node. -#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' Supported options: "auto" (choose automatically for task: If +#' numTrees == 1, set to "all." If numTrees > 1 +#' (forest), set to "sqrt" for classification and +#' to "onethird" for regression), +#' "all" (use all features), +#' "onethird" (use 1/3 of the features), +#' "sqrt" (use sqrt(number of features)), +#' "log2" (use log2(number of features)), +#' "n": (when n is in the range (0, 1.0], use +#' n * number of features. When n is in the range +#' (1, number of features), use n features). +#' Default is "auto". #' @param seed integer seed for random number generation. #' @param subsamplingRate Fraction of the training data used for learning each decision tree, in #' range (0, 1]. diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index a480ac606f10d..d3a9cbae7d808 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -158,11 +158,16 @@ sparkR.sparkContext <- function( " please use the --packages commandline instead", sep = ",")) } backendPort <- existingPort + authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET") + if (nchar(authSecret) == 0) { + stop("Auth secret not provided in environment.") + } } else { path <- tempfile(pattern = "backend_port") submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) + invisible(checkJavaVersion()) launchBackend( args = path, sparkHome = sparkHome, @@ -186,16 +191,27 @@ sparkR.sparkContext <- function( monitorPort <- readInt(f) rLibPath <- readString(f) connectionTimeout <- readInt(f) + + # Don't use readString() so that we can provide a useful + # error message if the R and Java versions are mismatched. + authSecretLen <- readInt(f) + if (length(authSecretLen) == 0 || authSecretLen == 0) { + stop("Unexpected EOF in JVM connection data. 
Mismatched versions?") + } + authSecret <- readStringData(f, authSecretLen) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || length(monitorPort) == 0 || monitorPort == 0 || - length(rLibPath) != 1) { + length(rLibPath) != 1 || length(authSecret) == 0) { stop("JVM failed to launch") } - assign(".monitorConn", - socketConnection(port = monitorPort, timeout = connectionTimeout), - envir = .sparkREnv) + + monitorConn <- socketConnection(port = monitorPort, blocking = TRUE, + timeout = connectionTimeout, open = "wb") + doServerAuth(monitorConn, authSecret) + + assign(".monitorConn", monitorConn, envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -205,7 +221,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort, timeout = connectionTimeout) + connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret) }, error = function(err) { stop("Failed to connect JVM\n") @@ -687,3 +703,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) { NULL } } + +# Utility function for sending auth data over a socket and checking the server's reply. +doServerAuth <- function(con, authSecret) { + if (nchar(authSecret) == 0) { + stop("Auth secret not provided.") + } + writeString(con, authSecret) + flush(con) + reply <- readString(con) + if (reply != "ok") { + close(con) + stop("Unexpected reply from server.") + } +} diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index fc83463f72cd4..5eccbdc9d3818 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -163,7 +163,7 @@ setMethod("isActive", #' #' @param x a StreamingQuery. #' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} -#' is called or an error has occured. +#' is called or an error has occurred. #' @return TRUE if query has terminated within the timeout period; nothing if timeout is not #' specified. #' @rdname awaitTermination diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index f1b5ecaa017df..c3501977e64bc 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -746,7 +746,7 @@ varargsToJProperties <- function(...) { props } -launchScript <- function(script, combinedArgs, wait = FALSE) { +launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr = "") { if (.Platform$OS.type == "windows") { scriptWithArgs <- paste(script, combinedArgs, sep = " ") # on Windows, intern = F seems to mean output to the console. (documentation on this is missing) @@ -756,7 +756,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE) { # stdout = F means discard output # stdout = "" means to its console (default) # Note that the console of this child process might not be the same as the running R process. 
- system2(script, combinedArgs, stdout = "", wait = wait) + system2(script, combinedArgs, stdout = stdout, wait = wait, stderr = stderr) } } diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index 823d26f12feee..80df3d8ce6e59 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,6 +18,10 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { + tryCatch(checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") }) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) @@ -50,6 +54,10 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { + tryCatch(checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") }) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 2e31dc5f728cd..fb9db63b07cd0 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) + port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout) + +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # Waits indefinitely for a socket connecion by default. selectTimeout <- NULL diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 00789d815bba8..c2adf613acb02 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -62,7 +62,7 @@ compute <- function(mode, partition, serializer, deserializer, key, # Transform the result data.frame back to a list of rows output <- split(output, seq(nrow(output))) } else { - # Serialize the ouput to a byte array + # Serialize the output to a byte array stopifnot(serializer == "byte") } } else { @@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) + outputCon <- socketConnection( port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index f0d0a5114f89f..288a2714a554e 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -240,3 +240,10 @@ test_that("add and get file to be downloaded with Spark job on every node", { unlink(path, recursive = TRUE) sparkR.session.stop() }) + +test_that("SPARK-25234: parallelize should not have integer overflow", { + sc <- sparkR.sparkContext(master = sparkRTestMaster) + # 47000 * 47000 exceeds integer range + parallelize(sc, 1:47000, 47000) + sparkR.session.stop() +}) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R 
b/R/pkg/tests/fulltests/test_mllib_classification.R index a46c47dccd02e..023686e75d50a 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -382,10 +382,10 @@ test_that("spark.mlp", { trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) traindf <- as.DataFrame(data[trainidxs, ]) testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) - model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3)) + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 2)) predictions <- predict(model, testdf) expect_error(collect(predictions)) - model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip") + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 2), handleInvalid = "skip") predictions <- predict(model, testdf) expect_equal(class(collect(predictions)$clicked[1]), "list") diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7105469ffc242..e1f3cf339e83f 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -734,8 +734,8 @@ test_that("test cache, uncache and clearCache", { clearCache() expect_true(dropTempView("table1")) - expect_error(uncacheTable("foo"), - "Error in uncacheTable : analysis error - Table or view not found: foo") + expect_error(uncacheTable("zxwtyswklpf"), + "Error in uncacheTable : analysis error - Table or view not found: zxwtyswklpf") }) test_that("insertInto() on a registered table", { @@ -1479,24 +1479,125 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains() and sort_array() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_max(df[[1]])))[[1]] + expect_equal(result, c(3, 6)) + + result <- collect(select(df, array_min(df[[1]])))[[1]] + expect_equal(result, c(1, 4)) + + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 0)) + + result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 6)) + + result <- collect(select(df, reverse(df[[1]])))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(4L, 5L, 6L))) + + df2 <- createDataFrame(list(list("abc"))) + result <- collect(select(df2, reverse(df2[[1]])))[[1]] + expect_equal(result, "cba") + + # Test array_distinct() and array_remove() + df <- createDataFrame(list(list(list(1L, 2L, 3L, 1L, 2L)), list(list(6L, 5L, 5L, 4L, 6L)))) + result <- collect(select(df, array_distinct(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(6L, 5L, 4L))) + + result <- collect(select(df, array_remove(df[[1]], 2L)))[[1]] + expect_equal(result, list(list(1L, 3L, 1L), list(6L, 5L, 5L, 4L, 6L))) + + # Test arrays_zip() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 4L))), schema = c("c1", "c2")) + result <- collect(select(df, arrays_zip(df[[1]], df[[2]])))[[1]] + expected_entries <- list(listToStruct(list(c1 = 1L, c2 = 3L)), + listToStruct(list(c1 = 2L, c2 = 4L))) + expect_equal(result, list(expected_entries)) + + # Test map_from_arrays() + df <- createDataFrame(list(list(list("x", "y"), list(1, 2))), schema = c("k", "v")) + result <- collect(select(df, map_from_arrays(df$k, 
df$v)))[[1]] + expected_entries <- list(as.environment(list(x = 1, y = 2))) + expect_equal(result, expected_entries) + + # Test array_repeat() + df <- createDataFrame(list(list("a", 3L), list("b", 2L))) + result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list("a", "a", "a"), list("b", "b"))) + + result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]] + expect_equal(result, list(list("a", "a"), list("b", "b"))) + + # Test arrays_overlap() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, NA), list(3L, 4L)))) + result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] + expect_equal(result, c(TRUE, FALSE, NA)) + + # Test array_join() + df <- createDataFrame(list(list(list("Hello", "World!")))) + result <- collect(select(df, array_join(df[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df2 <- createDataFrame(list(list(list("Hello", NA, "World!")))) + result <- collect(select(df2, array_join(df2[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df2, array_join(df2[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df3 <- createDataFrame(list(list(list("Hello", NULL, "World!")))) + result <- collect(select(df3, array_join(df3[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df3, array_join(df3[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + + # Test array_sort() and sort_array() + df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) + + result <- collect(select(df, array_sort(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, NA), list(4L, 5L, 6L, NA, NA))) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] - expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA))) result <- collect(select(df, sort_array(df[[1]])))[[1]] - expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) - - # Test map_keys() and map_values() + expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) + + # Test slice() + df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(4L, 5L)))) + result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] + expect_equal(result, list(list(2L, 3L), list(5L))) + + # Test concat() + df <- createDataFrame(list(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + list(list(7L, 8L, 9L), list(10L, 11L, 12L)))) + result <- collect(select(df, concat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L, 5L, 6L), list(7L, 8L, 9L, 10L, 11L, 12L))) + + # Test flatten() + df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), + list(list(list(5L, 6L), list(7L, 8L))))) + result <- collect(select(df, flatten(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) + + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) + result <- collect(select(df, map_entries(df$map)))[[1]] + expected_entries <- list(listToStruct(list(key = "x", value = 1)), + listToStruct(list(key = "y", value = 2))) + expect_equal(result, list(expected_entries)) + result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) result <- collect(select(df, map_values(df$map)))[[1]] expect_equal(result, list(list(1, 2))) 
+ result <- collect(select(df, element_at(df$map, "y")))[[1]] + expect_equal(result, 2) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) @@ -2188,8 +2289,8 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { expect_equal(count(where(join(df, df2), df$name == df2$name)), 3) # cartesian join expect_error(tryCatch(count(join(df, df2)), error = function(e) { stop(e) }), - paste0(".*(org.apache.spark.sql.AnalysisException: Detected cartesian product for", - " INNER join between logical plans).*")) + paste0(".*(org.apache.spark.sql.AnalysisException: Detected implicit cartesian", + " product for INNER join between logical plans).*")) joined <- crossJoin(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -2381,6 +2482,25 @@ test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataF unlink(jsonPath2) }) +test_that("intersectAll() and exceptAll()", { + df1 <- createDataFrame(list(list("a", 1), list("a", 1), list("a", 1), + list("a", 1), list("b", 3), list("c", 4)), + schema = c("a", "b")) + df2 <- createDataFrame(list(list("a", 1), list("a", 1), list("b", 3)), schema = c("a", "b")) + intersectAllExpected <- data.frame("a" = c("a", "a", "b"), "b" = c(1, 1, 3), + stringsAsFactors = FALSE) + exceptAllExpected <- data.frame("a" = c("a", "a", "c"), "b" = c(1, 1, 4), + stringsAsFactors = FALSE) + intersectAllDf <- arrange(intersectAll(df1, df2), df1$a) + expect_is(intersectAllDf, "SparkDataFrame") + exceptAllDf <- arrange(exceptAll(df1, df2), df1$a) + expect_is(exceptAllDf, "SparkDataFrame") + intersectAllActual <- collect(intersectAllDf) + expect_identical(intersectAllActual, intersectAllExpected) + exceptAllActual <- collect(exceptAllDf) + expect_identical(exceptAllActual, exceptAllExpected) +}) + test_that("withColumn() and withColumnRenamed()", { df <- read.json(jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) @@ -3512,11 +3632,11 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", { expect_equal(currentDatabase(), "default") expect_error(setCurrentDatabase("default"), NA) - expect_error(setCurrentDatabase("foo"), - "Error in setCurrentDatabase : analysis error - Database 'foo' does not exist") + expect_error(setCurrentDatabase("zxwtyswklpf"), + "Error in setCurrentDatabase : analysis error - Database 'zxwtyswklpf' does not exist") dbs <- collect(listDatabases()) expect_equal(names(dbs), c("name", "description", "locationUri")) - expect_equal(dbs[[1]], "default") + expect_equal(which(dbs[, 1] == "default"), 1) }) test_that("catalog APIs, listTables, listColumns, listFunctions", { @@ -3539,8 +3659,9 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { expect_equal(colnames(c), c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) expect_equal(collect(c)[[1]][[1]], "speed") - expect_error(listColumns("foo", "default"), - "Error in listColumns : analysis error - Table 'foo' does not exist in database 'default'") + expect_error(listColumns("zxwtyswklpf", "default"), + paste("Error in listColumns : analysis error - Table", + "'zxwtyswklpf' does not exist in database 'default'")) f <- listFunctions() expect_true(nrow(f) >= 200) # 250 @@ -3548,8 +3669,9 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { c("name", "database", "description", "className", "isTemporary")) expect_equal(take(orderBy(f, "className"), 1)$className, 
"org.apache.spark.sql.catalyst.expressions.Abs") - expect_error(listFunctions("foo_db"), - "Error in listFunctions : analysis error - Database 'foo_db' does not exist") + expect_error(listFunctions("zxwtyswklpf_db"), + paste("Error in listFunctions : analysis error - Database", + "'zxwtyswklpf_db' does not exist")) # recoverPartitions does not work with tempory view expect_error(recoverPartitions("cars"), diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R index f0292ab335592..b2b6f34aaa085 100644 --- a/R/pkg/tests/fulltests/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -103,7 +103,7 @@ test_that("cleanClosure on R functions", { expect_true("l" %in% ls(env)) expect_true("f" %in% ls(env)) expect_equal(get("l", envir = env, inherits = FALSE), l) - # "y" should be in the environemnt of g. + # "y" should be in the environment of g. newG <- get("g", envir = env, inherits = FALSE) env <- environment(newG) expect_equal(length(ls(env)), 1) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d4713de7806a1..090363c5f8a3e 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -590,6 +590,7 @@ summary(model) Predict values on training data ```{r} prediction <- predict(model, training) +head(select(prediction, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Logistic Regression @@ -613,6 +614,7 @@ summary(model) Predict values on training data ```{r} fitted <- predict(model, training) +head(select(fitted, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` Multinomial logistic regression against three classes @@ -652,7 +654,7 @@ We use Titanic data set to show how to use `spark.mlp` in classification. t <- as.data.frame(Titanic) training <- createDataFrame(t) # fit a Multilayer Perceptron Classification Model -model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 0, 5, 5, 5, 9, 9, 9)) +model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 2), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 5, 5, 9, 9)) ``` To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. 
@@ -807,6 +809,7 @@ df <- createDataFrame(t) dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2) summary(dtModel) predictions <- predict(dtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Gradient-Boosted Trees @@ -822,6 +825,7 @@ df <- createDataFrame(t) gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2) summary(gbtModel) predictions <- predict(gbtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Random Forest @@ -837,6 +841,7 @@ df <- createDataFrame(t) rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2) summary(rfModel) predictions <- predict(rfModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Bisecting k-Means diff --git a/README.md b/README.md index 1e521a7e7b178..531d330234062 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,8 @@ can be run using: Please see the guidance on how to [run tests for a module, or individual tests](http://spark.apache.org/developer-tools.html#individual-tests). +There is also a Kubernetes integration test, see resource-managers/kubernetes/integration-tests/README.md + ## A Note About Hadoop Versions Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported diff --git a/appveyor.yml b/appveyor.yml index aee94c59612d2..7fb45745a036f 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,7 +48,7 @@ install: - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" build_script: - - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package + - cmd: mvn -DskipTests -Psparkr -Phive package environment: NOT_CRAN: true diff --git a/assembly/README b/assembly/README index d5dafab477410..affd281a1385c 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=2.7.3 + -Dhadoop.version=2.7.7 diff --git a/assembly/pom.xml b/assembly/pom.xml index a207dae5a74ff..9608c96fd5369 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -254,6 +254,14 @@ spark-hadoop-cloud_${scala.binary.version} ${project.version} + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index f090240065bf1..d6371051ef7fb 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -49,6 +49,7 @@ function build { # Set image build arguments accordingly if this is a source repo and not a distribution archive. IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles BUILD_ARGS=( + ${BUILD_PARAMS} --build-arg img_path=$IMG_PATH --build-arg @@ -57,22 +58,38 @@ function build { else # Not passed as an argument to docker, but used to validate the Spark directory. IMG_PATH="kubernetes/dockerfiles" - BUILD_ARGS=() + BUILD_ARGS=(${BUILD_PARAMS}) fi if [ ! -d "$IMG_PATH" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." 
fi + local BINDING_BUILD_ARGS=( + ${BUILD_PARAMS} + --build-arg + base_img=$(image_ref spark) + ) + local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} + local RDOCKERFILE=${RDOCKERFILE:-"$IMG_PATH/spark/bindings/R/Dockerfile"} + + docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ + -t $(image_ref spark) \ + -f "$BASEDOCKERFILE" . - local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-py) \ + -f "$PYDOCKERFILE" . - docker build "${BUILD_ARGS[@]}" \ - -t $(image_ref spark) \ - -f "$DOCKERFILE" . + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-r) \ + -f "$RDOCKERFILE" . } function push { docker push "$(image_ref spark)" + docker push "$(image_ref spark-py)" + docker push "$(image_ref spark-r)" } function usage { @@ -86,10 +103,15 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: - -f file Dockerfile to build. By default builds the Dockerfile shipped with Spark. - -r repo Repository address. - -t tag Tag to apply to the built image, or to identify the image to be pushed. - -m Use minikube's Docker daemon. + -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. + -p file Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark. + -R file Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + -r repo Repository address. + -t tag Tag to apply to the built image, or to identify the image to be pushed. + -m Use minikube's Docker daemon. + -n Build docker image with --no-cache + -b arg Build arg to build or push the image. For multiple build args, this option needs to + be used separately for each build arg. Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -116,14 +138,22 @@ fi REPO= TAG= -DOCKERFILE= -while getopts f:mr:t: option +BASEDOCKERFILE= +PYDOCKERFILE= +RDOCKERFILE= +NOCACHEARG= +BUILD_PARAMS= +while getopts f:p:R:mr:t:n:b: option do case "${option}" in - f) DOCKERFILE=${OPTARG};; + f) BASEDOCKERFILE=${OPTARG};; + p) PYDOCKERFILE=${OPTARG};; + R) RDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; + n) NOCACHEARG="--no-cache";; + b) BUILD_PARAMS=${BUILD_PARAMS}" --build-arg "${OPTARG};; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." diff --git a/bin/pyspark b/bin/pyspark index dd286277c1fc1..5d5affb1f97c3 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option -# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython +# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython # and executor Python executables. 
# Fail noisily if removed options are set if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then - echo "Error in pyspark startup:" + echo "Error in pyspark startup:" echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead." exit 1 fi @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 663670f2fddaf..15fa910c277b3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/build/mvn b/build/mvn index efa4f9364ea52..ae4276dbc7e32 100755 --- a/build/mvn +++ b/build/mvn @@ -93,7 +93,7 @@ install_mvn() { install_zinc() { local zinc_path="zinc-0.3.15/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/zinc/0.3.15" \ @@ -109,7 +109,7 @@ install_scala() { # determine the Scala version used in Spark local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/scala/${scala_version}" \ @@ -154,4 +154,4 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" +"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 0e491efac9181..58e2a8f25f34f 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -234,7 +234,7 @@ public void close() throws IOException { * Closes the given iterator if the DB is still open. Trying to close a JNI LevelDB handle * with a closed DB can cause JVM crashes, so this ensures that situation does not happen. 
*/ - void closeIterator(LevelDBIterator it) throws IOException { + void closeIterator(LevelDBIterator it) throws IOException { synchronized (this._db) { DB _db = this._db.get(); if (_db != null) { diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java index 510b3058a4e3c..9abf26f02f7a7 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java @@ -35,7 +35,7 @@ public void testObjectWriteReadDelete() throws Exception { try { store.read(CustomType1.class, t.key); - fail("Expected exception for non-existant object."); + fail("Expected exception for non-existent object."); } catch (NoSuchElementException nsee) { // Expected. } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index b8123ac81d29a..205f7df87c5bc 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -80,7 +80,7 @@ public void testObjectWriteReadDelete() throws Exception { try { db.read(CustomType1.class, t.key); - fail("Expected exception for non-existant object."); + fail("Expected exception for non-existent object."); } catch (NoSuchElementException nsee) { // Expected. } diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 8b8f9892847c3..45fee541a4f5d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -77,16 +77,16 @@ public ByteBuffer nioByteBuffer() throws IOException { return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); } } catch (IOException e) { + String errorMessage = "Error in reading " + this; try { if (channel != null) { long size = channel.size(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); + errorMessage = "Error in reading " + this + " (actual file length " + size + ")"; } } catch (IOException ignored) { // ignore } - throw new IOException("Error in opening " + this, e); + throw new IOException(errorMessage, e); } finally { JavaUtils.closeQuietly(channel); } @@ -95,26 +95,24 @@ public ByteBuffer nioByteBuffer() throws IOException { @Override public InputStream createInputStream() throws IOException { FileInputStream is = null; + boolean shouldClose = true; try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return new LimitedInputStream(is, length); + InputStream r = new LimitedInputStream(is, length); + shouldClose = false; + return r; } catch (IOException e) { - try { - if (is != null) { - long size = file.length(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); - } - } catch (IOException ignored) { - // ignore - } finally { + String errorMessage = "Error in reading " + this; + if (is != null) { + long size = file.length(); + errorMessage = "Error in reading " + this + " (actual file length " + size + ")"; + } + throw new IOException(errorMessage, e); + } finally { 
+ if (shouldClose) { JavaUtils.closeQuietly(is); } - throw new IOException("Error in opening " + this, e); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(is); - throw e; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java new file mode 100644 index 0000000000000..bd173b653e33e --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +public interface StreamCallbackWithID extends StreamCallback { + String getID(); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index b0e85bae7c309..f3eb744ff7345 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -22,22 +22,24 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.TransportFrameDecoder; /** * An interceptor that is registered with the frame decoder to feed stream data to a * callback. 
*/ -class StreamInterceptor implements TransportFrameDecoder.Interceptor { +public class StreamInterceptor implements TransportFrameDecoder.Interceptor { - private final TransportResponseHandler handler; + private final MessageHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private long bytesRead; - StreamInterceptor( - TransportResponseHandler handler, + public StreamInterceptor( + MessageHandler handler, String streamId, long byteCount, StreamCallback callback) { @@ -50,16 +52,24 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } + private void deactivateStream() { + if (handler instanceof TransportResponseHandler) { + // we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches + // (there is no extra cleanup that needs to happen) + ((TransportResponseHandler) handler).deactivateStream(); + } + } + @Override public boolean handle(ByteBuf buf) throws Exception { int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead); @@ -72,10 +82,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); - handler.deactivateStream(); + deactivateStream(); throw re; } else if (bytesRead == byteCount) { - handler.deactivateStream(); + deactivateStream(); callback.onComplete(streamId); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bbaa..20d840baeaf6c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -32,15 +32,15 @@ import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.*; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -133,34 +133,21 @@ public void fetchChunk( long streamId, int chunkIndex, ChunkReceivedCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); - handler.addFetchRequest(streamChunkId, callback); - - channel.writeAndFlush(new 
ChunkFetchRequest(streamChunkId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", streamChunkId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); + StdChannelListener listener = new StdChannelListener(streamChunkId) { + @Override + void handleFailure(String errorMsg, Throwable cause) { handler.removeFetchRequest(streamChunkId); - channel.close(); - try { - callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } + callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); } - }); + }; + handler.addFetchRequest(streamChunkId, callback); + + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener); } /** @@ -170,7 +157,12 @@ public void fetchChunk( * @param callback Object to call with the stream data. */ public void stream(String streamId, StreamCallback callback) { - long startTime = System.currentTimeMillis(); + StdChannelListener listener = new StdChannelListener(streamId) { + @Override + void handleFailure(String errorMsg, Throwable cause) throws Exception { + callback.onFailure(streamId, new IOException(errorMsg, cause)); + } + }; if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); } @@ -180,25 +172,7 @@ public void stream(String streamId, StreamCallback callback) { // when responses arrive. synchronized (this) { handler.addStreamCallback(streamId, callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request for {} to {} took {} ms", streamId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - callback.onFailure(streamId, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener); } } @@ -211,35 +185,44 @@ public void stream(String streamId, StreamCallback callback) { * @return The RPC's id. 
*/ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isTraceEnabled()) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } - long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); + RpcChannelListener listener = new RpcChannelListener(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) - .addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", requestId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(requestId); - channel.close(); - try { - callback.onFailure(new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + .addListener(listener); + + return requestId; + } + + /** + * Send data to the remote end as a stream. This differs from stream() in that this is a request + * to *send* data to the remote end, not to receive it from the remote. + * + * @param meta meta data associated with the stream, which will be read completely on the + * receiving end before the stream itself. + * @param data this will be streamed to the remote end to allow for transferring large amounts + * of data without reading into memory. + * @param callback handles the reply -- onSuccess will only be called when both message and data + * are received successfully. 
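For illustration only, and not something this patch itself adds: a minimal sketch of how a caller might invoke the new uploadStream API, assuming an already-connected TransportClient named client and a ManagedBuffer named dataBuf holding the large payload (both names are hypothetical; the buffer and callback classes are the ones shown in this diff, and the method would sit inside some caller class).

    import java.nio.ByteBuffer;
    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.buffer.NioManagedBuffer;
    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.util.JavaUtils;

    // Sketch only: `client` and `dataBuf` are assumed to exist already.
    static long uploadExample(TransportClient client, ManagedBuffer dataBuf) {
      ByteBuffer header = JavaUtils.stringToBytes("my-stream");  // small header, buffered in memory
      return client.uploadStream(
          new NioManagedBuffer(header),
          dataBuf,  // large body, streamed to the remote end without being read fully into memory
          new RpcResponseCallback() {
            @Override
            public void onSuccess(ByteBuffer response) {
              // reached only after both the header and the streamed body were received remotely
            }

            @Override
            public void onFailure(Throwable e) {
              // reached if sending fails or the remote handler reports an error
            }
          });
    }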
+ */ + public long uploadStream( + ManagedBuffer meta, + ManagedBuffer data, + RpcResponseCallback callback) { + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } + + long requestId = requestId(); + handler.addRpcRequest(requestId, callback); + + RpcChannelListener listener = new RpcChannelListener(requestId, callback); + channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener); return requestId; } @@ -319,4 +302,60 @@ public String toString() { .add("isActive", isActive()) .toString(); } + + private static long requestId() { + return Math.abs(UUID.randomUUID().getLeastSignificantBits()); + } + + private class StdChannelListener + implements GenericFutureListener> { + final long startTime; + final Object requestId; + + StdChannelListener(Object requestId) { + this.startTime = System.currentTimeMillis(); + this.requestId = requestId; + } + + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + if (logger.isTraceEnabled()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + handleFailure(errorMsg, future.cause()); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + + void handleFailure(String errorMsg, Throwable cause) throws Exception {} + } + + private class RpcChannelListener extends StdChannelListener { + final long rpcRequestId; + final RpcResponseCallback callback; + + RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) { + super("RPC " + rpcRequestId); + this.rpcRequestId = rpcRequestId; + this.callback = callback; + } + + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeRpcRequest(rpcRequestId); + callback.onFailure(new IOException(errorMsg, cause)); + } + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 7a3d96ceaef0c..596b0ea5dba9b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -212,8 +212,8 @@ public void handle(ResponseMessage message) throws Exception { if (entry != null) { StreamCallback callback = entry.getValue(); if (resp.byteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); + StreamInterceptor interceptor = new StreamInterceptor<>( + this, resp.streamId, resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8a6e3858081bf..fb44dbbb0953b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -29,6 +29,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; @@ -149,6 +150,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index e04524dde0a75..b64e4b7a970b5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -240,7 +240,7 @@ public boolean release(int decrement) { @Override public long transferTo(WritableByteChannel target, long position) throws IOException { - Preconditions.checkArgument(position == transfered(), "Invalid position."); + Preconditions.checkArgument(position == transferred(), "Invalid position."); do { if (currentEncrypted == null) { @@ -267,7 +267,7 @@ private void encryptMore() throws IOException { int copied = byteRawChannel.write(buf.nioBuffer()); buf.skipBytes(copied); } else { - region.transferTo(byteRawChannel, region.transfered()); + region.transferTo(byteRawChannel, region.transferred()); } cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); cos.flush(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 434935a8ef2ad..0ccd70c03aba8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,7 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), User(-1); + OneWayMessage(9), UploadStream(10), User(-1); private final byte id; @@ -65,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case 10: return UploadStream; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 39a7495828a8a..bf80aed0afe10 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -80,6 +80,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case 
StreamFailure: return StreamFailure.decode(in); + case UploadStream: + return UploadStream.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index a5337656cbd84..b81c25afc737f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -137,30 +137,31 @@ protected void deallocate() { } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - ByteBuffer buffer = buf.nioBuffer(); - int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? - target.write(buffer) : writeNioBuffer(target, buffer); + // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance + // for the case that the passed-in buffer has too many components. + int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); + // If the ByteBuf holds more then one ByteBuffer we should better call nioBuffers(...) + // to eliminate extra memory copies. + int written = 0; + if (buf.nioBufferCount() == 1) { + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + written = target.write(buffer); + } else { + ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length); + for (ByteBuffer buffer: buffers) { + int remaining = buffer.remaining(); + int w = target.write(buffer); + written += w; + if (w < remaining) { + // Could not write all, we need to break now. + break; + } + } + } buf.skipBytes(written); return written; } - private int writeNioBuffer( - WritableByteChannel writeCh, - ByteBuffer buf) throws IOException { - int originalLimit = buf.limit(); - int ret = 0; - - try { - int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); - buf.limit(buf.position() + ioSize); - ret = writeCh.write(buf); - } finally { - buf.limit(originalLimit); - } - - return ret; - } - @Override public MessageWithHeader touch(Object o) { super.touch(o); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 87e212f3e157b..50b811604b84b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -67,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId, body()); + return Objects.hashCode(byteCount, streamId); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java new file mode 100644 index 0000000000000..fa1d26e76b852 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * An RPC with data that is sent outside of the frame, so it can be read as a stream. + */ +public final class UploadStream extends AbstractMessage implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + public final ManagedBuffer meta; + public final long bodyByteCount; + + public UploadStream(long requestId, ManagedBuffer meta, ManagedBuffer body) { + super(body, false); // body is *not* included in the frame + this.requestId = requestId; + this.meta = meta; + bodyByteCount = body.size(); + } + + // this version is called when decoding the bytes on the receiving end. The body is handled + // separately. + private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { + super(null, false); + this.requestId = requestId; + this.meta = meta; + this.bodyByteCount = bodyByteCount; + } + + @Override + public Type type() { return Type.UploadStream; } + + @Override + public int encodedLength() { + // the requestId, meta size, meta and bodyByteCount (body is not included) + return 8 + 4 + ((int) meta.size()) + 8; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + try { + ByteBuffer metaBuf = meta.nioByteBuffer(); + buf.writeInt(metaBuf.remaining()); + buf.writeBytes(metaBuf); + } catch (IOException io) { + throw new RuntimeException(io); + } + buf.writeLong(bodyByteCount); + } + + public static UploadStream decode(ByteBuf buf) { + long requestId = buf.readLong(); + int metaSize = buf.readInt(); + ManagedBuffer meta = new NettyManagedBuffer(buf.readRetainedSlice(metaSize)); + long bodyByteCount = buf.readLong(); + // This is called by the frame decoder, so the data is still null. We need a StreamInterceptor + // to read the data. 
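    // Layout of the frame consumed above, mirroring encode() and encodedLength():
    //   [requestId: 8 bytes][meta length: 4 bytes][meta: that many bytes][bodyByteCount: 8 bytes]
    // The body bytes are never part of this frame; they follow on the wire and are handed to the
    // stream callback by the StreamInterceptor that the request handler installs on the frame decoder.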
+ return new UploadStream(requestId, meta, bodyByteCount); + } + + @Override + public int hashCode() { + return Long.hashCode(requestId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UploadStream) { + UploadStream o = (UploadStream) other; + return requestId == o.requestId && super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("body", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 3ac9081d78a75..e1275689ae6a0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -135,13 +135,14 @@ static class EncryptedMessage extends AbstractFileRegion { private final boolean isByteBuf; private final ByteBuf buf; private final FileRegion region; + private final int maxOutboundBlockSize; /** * A channel used to buffer input data for encryption. The channel has an upper size bound * so that if the input is larger than the allowed buffer, it will be broken into multiple - * chunks. + * chunks. Made non-final to enable lazy initialization, which saves memory. */ - private final ByteArrayWritableChannel byteChannel; + private ByteArrayWritableChannel byteChannel; private ByteBuf currentHeader; private ByteBuffer currentChunk; @@ -157,7 +158,7 @@ static class EncryptedMessage extends AbstractFileRegion { this.isByteBuf = msg instanceof ByteBuf; this.buf = isByteBuf ? (ByteBuf) msg : null; this.region = isByteBuf ? null : (FileRegion) msg; - this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + this.maxOutboundBlockSize = maxOutboundBlockSize; } /** @@ -230,17 +231,17 @@ public boolean release(int decrement) { * data into memory at once, and can avoid ballooning memory usage when transferring large * messages such as shuffle blocks. * - * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward + * The {@link #transferred()} counter also behaves a little funny, in that it won't go forward * until a whole chunk has been written. This is done because the code can't use the actual * number of bytes written to the channel as the transferred count (see {@link #count()}). * Instead, once an encrypted chunk is written to the output (including its header), the - * size of the original block will be added to the {@link #transfered()} amount. + * size of the original block will be added to the {@link #transferred()} amount. */ @Override public long transferTo(final WritableByteChannel target, final long position) throws IOException { - Preconditions.checkArgument(position == transfered(), "Invalid position."); + Preconditions.checkArgument(position == transferred(), "Invalid position."); long reportedWritten = 0L; long actuallyWritten = 0L; @@ -272,7 +273,7 @@ public long transferTo(final WritableByteChannel target, final long position) currentChunkSize = 0; currentReportedBytes = 0; } - } while (currentChunk == null && transfered() + reportedWritten < count()); + } while (currentChunk == null && transferred() + reportedWritten < count()); // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead, // we return 1 until we can (i.e. 
until the reported count would actually match the size @@ -292,12 +293,15 @@ public long transferTo(final WritableByteChannel target, final long position) } private void nextChunk() throws IOException { + if (byteChannel == null) { + byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + } byteChannel.reset(); if (isByteBuf) { int copied = byteChannel.write(buf.nioBuffer()); buf.skipBytes(copied); } else { - region.transferTo(byteChannel, region.transfered()); + region.transferTo(byteChannel, region.transferred()); } byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 0231428318add..355a3def8cc22 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -132,6 +133,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 8f7554e2e07d5..38569baf82bce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; /** @@ -36,7 +37,8 @@ public abstract class RpcHandler { * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * - * This method will not be called in parallel for a single TransportClient (i.e., channel). + * Neither this method nor #receiveStream will be called in parallel for a single + * TransportClient (i.e., channel). * * @param client A channel client which enables the handler to make requests back to the sender * of this RPC. This will always be the exact same object for a particular channel. @@ -49,6 +51,36 @@ public abstract void receive( ByteBuffer message, RpcResponseCallback callback); + /** + * Receive a single RPC message which includes data that is to be received as a stream. Any + * exception thrown while in this method will be sent back to the client in string form as a + * standard RPC failure. + * + * Neither this method nor #receive will be called in parallel for a single TransportClient + * (i.e., channel). 
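To make the new server-side hook concrete, here is a minimal, hypothetical RpcHandler that accepts an uploaded stream and merely counts its bytes. It is only a sketch assembled from the API in this patch (the class name, the stream-id convention, and the byte counting are all made up, not code the patch adds); a real handler, such as the one added to RpcIntegrationSuite further below, would normally persist the data and verify it in onComplete.

    import java.nio.ByteBuffer;
    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.StreamCallbackWithID;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.server.OneForOneStreamManager;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.StreamManager;
    import org.apache.spark.network.util.JavaUtils;

    public class ByteCountingRpcHandler extends RpcHandler {

      @Override
      public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
        // Plain RPCs are acknowledged immediately in this sketch.
        callback.onSuccess(ByteBuffer.allocate(0));
      }

      @Override
      public StreamCallbackWithID receiveStream(
          TransportClient client,
          ByteBuffer messageHeader,
          RpcResponseCallback callback) {
        // Treat the small, fully buffered header as the stream id.
        String streamId = JavaUtils.bytesToString(messageHeader);
        return new StreamCallbackWithID() {
          private long received = 0;

          @Override
          public String getID() { return streamId; }

          @Override
          public void onData(String id, ByteBuffer buf) {
            received += buf.remaining();  // called as chunks of the streamed body arrive
          }

          @Override
          public void onComplete(String id) {
            // The request handler wraps this callback and replies to the client once it returns.
          }

          @Override
          public void onFailure(String id, Throwable cause) {
            // Reading the stream failed; per the contract above, the whole channel fails.
          }
        };
      }

      @Override
      public StreamManager getStreamManager() {
        return new OneForOneStreamManager();
      }
    }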
+ * + * An error while reading data from the stream + * ({@link org.apache.spark.network.client.StreamCallback#onData(String, ByteBuffer)}) + * will fail the entire channel. A failure in "post-processing" the stream in + * {@link org.apache.spark.network.client.StreamCallback#onComplete(String)} will result in an + * rpcFailure, but the channel will remain active. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param messageHeader The serialized bytes of the header portion of the RPC. This is in meant + * to be relatively small, and will be buffered entirely in memory, to + * facilitate how the streaming portion should be received. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + * @return a StreamCallback for handling the accompanying streaming data + */ + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e94453578e6b0..9fac96dbe450d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.io.IOException; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -28,20 +29,10 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.util.TransportFrameDecoder; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -52,6 +43,7 @@ * The messages should have been processed by the pipeline setup by {@link TransportServer}. */ public class TransportRequestHandler extends MessageHandler { + private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); /** The Netty channel that this handler is associated with. 
*/ @@ -113,6 +105,8 @@ public void handle(RequestMessage request) { processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); + } else if (request instanceof UploadStream) { + processStreamUpload((UploadStream) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -203,6 +197,79 @@ public void onFailure(Throwable e) { } } + /** + * Handle a request from the client to upload a stream of data. + */ + private void processStreamUpload(final UploadStream req) { + assert (req.body() == null); + try { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); + } + + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }; + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + ByteBuffer meta = req.meta.nioByteBuffer(); + StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); + if (streamHandler == null) { + throw new NullPointerException("rpcHandler returned a null streamHandler"); + } + StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + streamHandler.onData(streamId, buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + try { + streamHandler.onComplete(streamId); + callback.onSuccess(ByteBuffer.allocate(0)); + } catch (Exception ex) { + IOException ioExc = new IOException("Failure post-processing complete stream;" + + " failing this rpc and leaving channel active", ex); + callback.onFailure(ioExc); + streamHandler.onFailure(streamId, ioExc); + } + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + callback.onFailure(new IOException("Destination failed while reading stream", cause)); + streamHandler.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamHandler.getID(); + } + }; + if (req.bodyByteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor<>( + this, wrappedCallback.getID(), req.bodyByteCount, wrappedCallback); + frameDecoder.setInterceptor(interceptor); + } else { + wrappedCallback.onComplete(wrappedCallback.getID()); + } + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + // We choose to totally fail the channel, rather than trying to recover as we do in other + // cases. We don't know how many bytes of the stream the client has already sent for the + // stream, it's not worth trying to recover. 
+ channel.pipeline().fireExceptionCaught(e); + } finally { + req.meta.release(); + } + } + private void processOneWayMessage(OneWayMessage req) { try { rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 0719fa7647bcc..9c85ab2f5f06f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -32,6 +32,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,11 +70,14 @@ public TransportServer( this.appRpcHandler = appRpcHandler; this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); + boolean shouldClose = true; try { init(hostToBind, portToBind); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(this); - throw e; + shouldClose = false; + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(this); + } } } @@ -98,6 +102,7 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) + .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS) .childOption(ChannelOption.ALLOCATOR, allocator); this.metrics = new NettyMemoryMetrics( @@ -146,11 +151,11 @@ public void close() { channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); channelFuture = null; } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully(); + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully(); } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully(); + if (bootstrap != null && bootstrap.config().childGroup() != null) { + bootstrap.config().childGroup().shutdownGracefully(); } bootstrap = null; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index afc59efaef810..b5497087634ce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,10 +17,7 @@ package org.apache.spark.network.util; -import java.io.Closeable; -import java.io.EOFException; -import java.io.File; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; @@ -91,11 +88,24 @@ public static String bytesToString(ByteBuffer b) { * @throws IOException if deletion is unsuccessful */ public static void deleteRecursively(File file) throws IOException { + deleteRecursively(file, null); + } + + /** + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * + * @param file Input file / dir to be deleted + * @param filter A filename filter that make sure only files / dirs with the satisfied filenames + * are deleted. 
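As a usage illustration for the new overload (a sketch only; the directory path and the ".crc" suffix are made-up examples, not anything this patch prescribes):

    import java.io.File;
    import java.io.FilenameFilter;
    import java.io.IOException;
    import org.apache.spark.network.util.JavaUtils;

    public class FilteredCleanup {
      public static void main(String[] args) throws IOException {
        File dir = new File("/tmp/spark-scratch");  // hypothetical directory
        FilenameFilter crcOnly = (parent, name) -> name.endsWith(".crc");
        // Only entries whose names pass the filter are listed and deleted; everything else is left
        // untouched, and a directory (including `dir` itself) is removed only once it is empty.
        // Passing a non-null filter also bypasses the native Unix fast path shown in this diff.
        JavaUtils.deleteRecursively(dir, crcOnly);
      }
    }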
+ * @throws IOException if deletion is unsuccessful + */ + public static void deleteRecursively(File file, FilenameFilter filter) throws IOException { if (file == null) { return; } // On Unix systems, use operating system command to run faster // If that does not work out, fallback to the Java IO way - if (SystemUtils.IS_OS_UNIX) { + if (SystemUtils.IS_OS_UNIX && filter == null) { try { deleteRecursivelyUsingUnixNative(file); return; @@ -105,15 +115,17 @@ public static void deleteRecursively(File file) throws IOException { } } - deleteRecursivelyUsingJavaIO(file); + deleteRecursivelyUsingJavaIO(file, filter); } - private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { + private static void deleteRecursivelyUsingJavaIO( + File file, + FilenameFilter filter) throws IOException { if (file.isDirectory() && !isSymlink(file)) { IOException savedIOException = null; - for (File child : listFilesSafely(file)) { + for (File child : listFilesSafely(file, filter)) { try { - deleteRecursively(child); + deleteRecursively(child, filter); } catch (IOException e) { // In case of multiple exceptions, only last one will be thrown savedIOException = e; @@ -124,10 +136,13 @@ private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { } } - boolean deleted = file.delete(); - // Delete can also fail if the file simply did not exist. - if (!deleted && file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath()); + // Delete file only when it's a normal file or an empty directory. + if (file.isFile() || (file.isDirectory() && listFilesSafely(file, null).length == 0)) { + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } } } @@ -157,9 +172,9 @@ private static void deleteRecursivelyUsingUnixNative(File file) throws IOExcepti } } - private static File[] listFilesSafely(File file) throws IOException { + private static File[] listFilesSafely(File file, FilenameFilter filter) throws IOException { if (file.exists()) { - File[] files = file.listFiles(); + File[] files = file.listFiles(filter); if (files == null) { throw new IOException("Failed to list files for dir: " + file); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 5e85180bd6f9f..33d6eb4a83a0c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -17,7 +17,6 @@ package org.apache.spark.network.util; -import java.lang.reflect.Field; import java.util.concurrent.ThreadFactory; import io.netty.buffer.PooledByteBufAllocator; @@ -111,24 +110,14 @@ public static PooledByteBufAllocator createPooledByteBufAllocator( } return new PooledByteBufAllocator( allowDirectBufs && PlatformDependent.directBufferPreferred(), - Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores), - Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0), - getPrivateStaticField("DEFAULT_PAGE_SIZE"), - getPrivateStaticField("DEFAULT_MAX_ORDER"), - allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0, - allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0, - allowCache ? 
getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0 + Math.min(PooledByteBufAllocator.defaultNumHeapArena(), numCores), + Math.min(PooledByteBufAllocator.defaultNumDirectArena(), allowDirectBufs ? numCores : 0), + PooledByteBufAllocator.defaultPageSize(), + PooledByteBufAllocator.defaultMaxOrder(), + allowCache ? PooledByteBufAllocator.defaultTinyCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultSmallCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultNormalCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultUseCacheForAllThreads() : false ); } - - /** Used to get defaults from Netty's private static fields. */ - private static int getPrivateStaticField(String name) { - try { - Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); - f.setAccessible(true); - return f.getInt(null); - } catch (Exception e) { - throw new RuntimeException(e); - } - } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 91497b9492219..34e4bb5912dcb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -209,7 +209,7 @@ public String keyFactoryAlgorithm() { * (128 bits by default), which is not generally the case with user passwords. */ public int keyFactoryIterations() { - return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024); + return conf.getInt("spark.network.crypto.keyFactoryIterations", 1024); } /** diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index bc94f7ca63a96..6fb44fea8c5a4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -116,8 +116,8 @@ public void encode(ChannelHandlerContext ctx, FileRegion in, List out) throws Exception { ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); - while (in.transfered() < in.count()) { - in.transferTo(channel, in.transfered()); + while (in.transferred() < in.count()) { + in.transferTo(channel, in.transferred()); } out.add(Unpooled.wrappedBuffer(channel.getData())); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8ff737b129641..1f4d75c7e2ec5 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,43 +17,46 @@ package org.apache.spark.network; +import java.io.*; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Set; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import com.google.common.collect.Sets; +import com.google.common.io.Files; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; 
import static org.junit.Assert.*; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.*; +import org.apache.spark.network.server.*; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { + static TransportConf conf; static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; static List oneWayMsgs; + static StreamTestHelper testData; + + static ConcurrentHashMap streamCallbacks = + new ConcurrentHashMap<>(); @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + testData = new StreamTestHelper(); rpcHandler = new RpcHandler() { @Override public void receive( @@ -71,6 +74,14 @@ public void receive( } } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + return receiveStreamHelper(JavaUtils.bytesToString(messageHeader)); + } + @Override public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs.add(JavaUtils.bytesToString(message)); @@ -85,10 +96,71 @@ public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs = new ArrayList<>(); } + private static StreamCallbackWithID receiveStreamHelper(String msg) { + try { + if (msg.startsWith("fail/")) { + String[] parts = msg.split("/"); + switch (parts[1]) { + case "exception-ondata": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + throw new IOException("failed to read stream data!"); + } + + @Override + public void onComplete(String streamId) throws IOException { + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "exception-oncomplete": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + } + + @Override + public void onComplete(String streamId) throws IOException { + throw new IOException("exception in onComplete"); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "null": + return null; + default: + throw new IllegalArgumentException("unexpected msg: " + msg); + } + } else { + VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); + streamCallbacks.put(msg, streamCallback); + return streamCallback; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @AfterClass public static void tearDown() { server.close(); clientFactory.close(); + testData.cleanup(); } static class RpcResult { @@ -130,6 +202,59 @@ public void onFailure(Throwable e) { return res; } + private 
RpcResult sendRpcWithStream(String... streams) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + + for (String stream : streams) { + int idx = stream.lastIndexOf('/'); + ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream)); + String streamName = (idx == -1) ? stream : stream.substring(idx + 1); + ManagedBuffer data = testData.openStream(conf, streamName); + client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem)); + } + + if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + streamCallbacks.values().forEach(streamCallback -> { + try { + streamCallback.verify(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + client.close(); + return res; + } + + private static class RpcStreamCallback implements RpcResponseCallback { + final String streamId; + final RpcResult res; + final Semaphore sem; + + RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) { + this.streamId = streamId; + this.res = res; + this.sem = sem; + } + + @Override + public void onSuccess(ByteBuffer message) { + res.successMessages.add(streamId); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + } + @Test public void singleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron"); @@ -193,10 +318,83 @@ public void sendOneWayMessage() throws Exception { } } + @Test + public void sendRpcWithStreamOneAtATime() throws Exception { + for (String stream : StreamTestHelper.STREAMS) { + RpcResult res = sendRpcWithStream(stream); + assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty()); + assertEquals(Sets.newHashSet(stream), res.successMessages); + } + } + + @Test + public void sendRpcWithStreamConcurrently() throws Exception { + String[] streams = new String[10]; + for (int i = 0; i < 10; i++) { + streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length]; + } + RpcResult res = sendRpcWithStream(streams); + assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void sendRpcWithStreamFailures() throws Exception { + // when there is a failure reading stream data, we don't try to keep the channel usable, + // just send back a decent error msg. 
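    // More precisely, per the RpcHandler#receiveStream contract added in this patch: a failure in
    // onData, or a null stream handler, fails the whole channel, while a failure in onComplete only
    // fails that particular RPC and leaves the channel usable for later requests.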
+ RpcResult exceptionInCallbackResult = + sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + RpcResult nullStreamHandler = + sendRpcWithStream("fail/null/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + // OTOH, if there is a failure during onComplete, the channel should still be fine + RpcResult exceptionInOnComplete = + sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer"); + assertErrorsContain(exceptionInOnComplete.errorMessages, + Sets.newHashSet("Failure post-processing")); + assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages); + } + private void assertErrorsContain(Set errors, Set contains) { - assertEquals(contains.size(), errors.size()); + assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + "errors: " + + errors, contains.size(), errors.size()); + + Pair, Set> r = checkErrorsContain(errors, contains); + assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors, + r.getRight().isEmpty()); + + assertTrue(r.getLeft().isEmpty()); + } + + private void assertErrorAndClosed(RpcResult result, String expectedError) { + assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); + // we expect 1 additional error, which contains *either* "closed" or "Connection reset" + Set errors = result.errorMessages; + assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + + errors, 2, errors.size()); + + Set containsAndClosed = Sets.newHashSet(expectedError); + containsAndClosed.add("closed"); + containsAndClosed.add("Connection reset"); + + Pair, Set> r = checkErrorsContain(errors, containsAndClosed); + Set errorsNotFound = r.getRight(); + assertEquals(1, errorsNotFound.size()); + String err = errorsNotFound.iterator().next(); + assertTrue(err.equals("closed") || err.equals("Connection reset")); + + assertTrue(r.getLeft().isEmpty()); + } + + private Pair, Set> checkErrorsContain( + Set errors, + Set contains) { Set remainingErrors = Sets.newHashSet(errors); + Set notFound = Sets.newHashSet(); for (String contain : contains) { Iterator it = remainingErrors.iterator(); boolean foundMatch = false; @@ -207,9 +405,66 @@ private void assertErrorsContain(Set errors, Set contains) { break; } } - assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + if (!foundMatch) { + notFound.add(contain); + } + } + return new ImmutablePair<>(remainingErrors, notFound); + } + + private static class VerifyingStreamCallback implements StreamCallbackWithID { + final String streamId; + final StreamSuite.TestCallback helper; + final OutputStream out; + final File outFile; + + VerifyingStreamCallback(String streamId) throws IOException { + if (streamId.equals("file")) { + outFile = File.createTempFile("data", ".tmp", testData.tempDir); + out = new FileOutputStream(outFile); + } else { + out = new ByteArrayOutputStream(); + outFile = null; + } + this.streamId = streamId; + helper = new StreamSuite.TestCallback(out); + } + + void verify() throws IOException { + if (streamId.equals("file")) { + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); + } else { + byte[] result = ((ByteArrayOutputStream)out).toByteArray(); + ByteBuffer srcBuffer = testData.srcBuffer(streamId); + ByteBuffer base; + synchronized (srcBuffer) 
{ + base = srcBuffer.duplicate(); + } + byte[] expected = new byte[base.remaining()]; + base.get(expected); + assertEquals(expected.length, result.length); + assertTrue("buffers don't match", Arrays.equals(expected, result)); + } + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + helper.onData(streamId, buf); } - assertTrue(remainingErrors.isEmpty()); + @Override + public void onComplete(String streamId) throws IOException { + helper.onComplete(streamId); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + helper.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamId; + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f253a07e64be1..f3050cb79cdfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -26,7 +26,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Random; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; @@ -37,9 +36,7 @@ import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; @@ -51,16 +48,11 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + private static final String[] STREAMS = StreamTestHelper.STREAMS; + private static StreamTestHelper testData; private static TransportServer server; private static TransportClientFactory clientFactory; - private static File testFile; - private static File tempDir; - - private static ByteBuffer emptyBuffer; - private static ByteBuffer smallBuffer; - private static ByteBuffer largeBuffer; private static ByteBuffer createBuffer(int bufSize) { ByteBuffer buf = ByteBuffer.allocate(bufSize); @@ -73,23 +65,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { - tempDir = Files.createTempDir(); - emptyBuffer = createBuffer(0); - smallBuffer = createBuffer(100); - largeBuffer = createBuffer(100000); - - testFile = File.createTempFile("stream-test-file", "txt", tempDir); - FileOutputStream fp = new FileOutputStream(testFile); - try { - Random rnd = new Random(); - for (int i = 0; i < 512; i++) { - byte[] fileContent = new byte[1024]; - rnd.nextBytes(fileContent); - fp.write(fileContent); - } - } finally { - fp.close(); - } + testData = new StreamTestHelper(); final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @@ -100,18 +76,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { @Override public ManagedBuffer openStream(String streamId) { - switch (streamId) { - case "largeBuffer": - return new NioManagedBuffer(largeBuffer); - case "smallBuffer": - return new NioManagedBuffer(smallBuffer); - case "emptyBuffer": - return 
new NioManagedBuffer(emptyBuffer); - case "file": - return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); - default: - throw new IllegalArgumentException("Invalid stream: " + streamId); - } + return testData.openStream(conf, streamId); } }; RpcHandler handler = new RpcHandler() { @@ -137,12 +102,7 @@ public StreamManager getStreamManager() { public static void tearDown() { server.close(); clientFactory.close(); - if (tempDir != null) { - for (File f : tempDir.listFiles()) { - f.delete(); - } - tempDir.delete(); - } + testData.cleanup(); } @Test @@ -234,21 +194,21 @@ public void run() { case "largeBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = largeBuffer; + srcBuffer = testData.largeBuffer; break; case "smallBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = smallBuffer; + srcBuffer = testData.smallBuffer; break; case "file": - outFile = File.createTempFile("data", ".tmp", tempDir); + outFile = File.createTempFile("data", ".tmp", testData.tempDir); out = new FileOutputStream(outFile); break; case "emptyBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = emptyBuffer; + srcBuffer = testData.emptyBuffer; break; default: throw new IllegalArgumentException(streamId); @@ -256,10 +216,10 @@ public void run() { TestCallback callback = new TestCallback(out); client.stream(streamId, callback); - waitForCompletion(callback); + callback.waitForCompletion(timeoutMs); if (srcBuffer == null) { - assertTrue("File stream did not match.", Files.equal(testFile, outFile)); + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); } else { ByteBuffer base; synchronized (srcBuffer) { @@ -292,23 +252,9 @@ public void check() throws Throwable { throw error; } } - - private void waitForCompletion(TestCallback callback) throws Exception { - long now = System.currentTimeMillis(); - long deadline = now + timeoutMs; - synchronized (callback) { - while (!callback.completed && now < deadline) { - callback.wait(deadline - now); - now = System.currentTimeMillis(); - } - } - assertTrue("Timed out waiting for stream.", callback.completed); - assertNull(callback.error); - } - } - private static class TestCallback implements StreamCallback { + static class TestCallback implements StreamCallback { private final OutputStream out; public volatile boolean completed; @@ -344,6 +290,22 @@ public void onFailure(String streamId, Throwable cause) { } } + void waitForCompletion(long timeoutMs) { + long now = System.currentTimeMillis(); + long deadline = now + timeoutMs; + synchronized (this) { + while (!completed && now < deadline) { + try { + wait(deadline - now); + } catch (InterruptedException ie) { + throw new RuntimeException(ie); + } + now = System.currentTimeMillis(); + } + } + assertTrue("Timed out waiting for stream.", completed); + assertNull(error); + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java new file mode 100644 index 0000000000000..0f5c82c9e9b1f --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Random; + +import com.google.common.io.Files; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +class StreamTestHelper { + static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + + final File testFile; + final File tempDir; + + final ByteBuffer emptyBuffer; + final ByteBuffer smallBuffer; + final ByteBuffer largeBuffer; + + private static ByteBuffer createBuffer(int bufSize) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + return buf; + } + + StreamTestHelper() throws Exception { + tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); + smallBuffer = createBuffer(100); + largeBuffer = createBuffer(100000); + + testFile = File.createTempFile("stream-test-file", "txt", tempDir); + FileOutputStream fp = new FileOutputStream(testFile); + try { + Random rnd = new Random(); + for (int i = 0; i < 512; i++) { + byte[] fileContent = new byte[1024]; + rnd.nextBytes(fileContent); + fp.write(fileContent); + } + } finally { + fp.close(); + } + } + + public ByteBuffer srcBuffer(String name) { + switch (name) { + case "largeBuffer": + return largeBuffer; + case "smallBuffer": + return smallBuffer; + case "emptyBuffer": + return emptyBuffer; + default: + throw new IllegalArgumentException("Invalid stream: " + name); + } + } + + public ManagedBuffer openStream(TransportConf conf, String streamId) { + switch (streamId) { + case "file": + return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); + default: + return new NioManagedBuffer(srcBuffer(streamId)); + } + } + + void cleanup() { + if (tempDir != null) { + try { + JavaUtils.deleteRecursively(tempDir); + } catch (IOException io) { + throw new RuntimeException(io); + } + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index ecb66fcf2ff76..3bff34e210e3c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -22,6 +22,7 @@ import java.nio.channels.WritableByteChannel; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import org.apache.spark.network.util.AbstractFileRegion; import org.junit.Test; @@ -48,7 +49,36 @@ public void testShortWrite() throws 
Exception { @Test public void testByteBufBody() throws Exception { + testByteBufBody(Unpooled.copyLong(42)); + } + + @Test + public void testCompositeByteBufBodySingleBuffer() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, header); + assertEquals(1, compositeByteBuf.nioBufferCount()); + testByteBufBody(compositeByteBuf); + } + + @Test + public void testCompositeByteBufBodyMultipleBuffers() throws Exception { ByteBuf header = Unpooled.copyLong(42); + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, header.retainedSlice(0, 4)); + compositeByteBuf.addComponent(true, header.slice(4, 4)); + assertEquals(2, compositeByteBuf.nioBufferCount()); + testByteBufBody(compositeByteBuf); + } + + /** + * Test writing a {@link MessageWithHeader} using the given {@link ByteBuf} as header. + * + * @param header the header to use. + * @throws Exception thrown on error. + */ + private void testByteBufBody(ByteBuf header) throws Exception { + long expectedHeaderValue = header.getLong(header.readerIndex()); ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); assertEquals(1, header.refCnt()); assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); @@ -61,7 +91,7 @@ public void testByteBufBody() throws Exception { MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); ByteBuf result = doWrite(msg, 1); assertEquals(msg.count(), result.readableBytes()); - assertEquals(42, result.readLong()); + assertEquals(expectedHeaderValue, result.readLong()); assertEquals(84, result.readLong()); assertTrue(msg.release()); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index fc7bba41185f0..098fa7974b87b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -138,6 +138,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + /** + * Clean up any non-shuffle files in any local directories associated with a finished executor. + */ + public void executorRemoved(String executorId, String appId) { + blockManager.executorRemoved(executorId, appId); + } + /** * Register an (application, executor) with the given shuffle info.
* diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index e6399897be9c2..0b7a27402369d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -24,6 +24,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -59,6 +61,7 @@ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); private static final ObjectMapper mapper = new ObjectMapper(); + /** * This a common prefix to the key for each app registration we stick in leveldb, so they * are easy to find, since leveldb lets you search based on prefix. @@ -66,6 +69,8 @@ public class ExternalShuffleBlockResolver { private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}"); + // Map containing all registered executors' metadata. @VisibleForTesting final ConcurrentMap executors; @@ -211,6 +216,26 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { } } + /** + * Removes all the non-shuffle files in any local directories associated with the finished + * executor. + */ + public void executorRemoved(String executorId, String appId) { + logger.info("Clean up non-shuffle files associated with the finished executor {}", executorId); + AppExecId fullId = new AppExecId(appId, executorId); + final ExecutorShuffleInfo executor = executors.get(fullId); + if (executor == null) { + // Executor not registered, skip clean up of the local directories. + logger.info("Executor is not registered (appId={}, execId={})", appId, executorId); + } else { + logger.info("Cleaning up non-shuffle files in executor {}'s {} local dirs", fullId, + executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(() -> deleteNonShuffleFiles(executor.localDirs)); + } + } + /** * Synchronously deletes each directory one at a time. * Should be executed in its own thread, as this may take a long time. @@ -226,6 +251,29 @@ private void deleteExecutorDirs(String[] dirs) { } } + /** + * Synchronously deletes non-shuffle files in each directory recursively. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteNonShuffleFiles(String[] dirs) { + FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. 
+ return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir), filter); + logger.debug("Successfully cleaned up non-shuffle files in directory: {}", localDir); + } catch (Exception e) { + logger.error("Failed to delete non-shuffle files in directory: " + localDir, e); + } + } + } + /** * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, @@ -259,7 +307,8 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; - return new File(new File(localDir, String.format("%02x", subDirId)), filename); + return new File(createNormalizedInternedPathname( + localDir, String.format("%02x", subDirId), filename)); } void close() { @@ -272,6 +321,28 @@ void close() { } } + /** + * This method is needed to avoid the situation when multiple File instances for the + * same pathname "foo/bar" are created, each with a separate copy of the "foo/bar" String. + * According to measurements, in some scenarios such duplicate strings may waste a lot + * of memory (~ 10% of the heap). To avoid that, we intern the pathname, and before that + * we make sure that it's in a normalized form (contains no "//", "///" etc.) Otherwise, + * the internal code in java.io.File would normalize it later, creating a new "foo/bar" + * String copy. Unfortunately, we cannot just reuse the normalization code that java.io.File + * uses, since it is in the package-private class java.io.FileSystem. + */ + @VisibleForTesting + static String createNormalizedInternedPathname(String dir1, String dir2, String fname) { + String pathname = dir1 + File.separator + dir2 + File.separator + fname; + Matcher m = MULTIPLE_SEPARATORS.matcher(pathname); + pathname = m.replaceAll("/"); + // A single trailing slash needs to be taken care of separately + if (pathname.length() > 1 && pathname.endsWith("/")) { + pathname = pathname.substring(0, pathname.length() - 1); + } + return pathname.intern(); + } + /** Simply encodes an executor's full ID, which is appId + execId. */ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 9af6759f5d5f3..a68a297519b66 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -42,7 +42,7 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. 
*/ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5); + HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6); private final byte id; @@ -67,6 +67,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); case 5: return ShuffleServiceHeartbeat.decode(buf); + case 6: return UploadBlockStream.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java new file mode 100644 index 0000000000000..9df30967d5bb2 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * A request to Upload a block, which the destination should receive as a stream. + * + * The actual block data is not contained here. 
It will be passed to the StreamCallbackWithID + * that is returned from RpcHandler.receiveStream() + */ +public class UploadBlockStream extends BlockTransferMessage { + public final String blockId; + public final byte[] metadata; + + public UploadBlockStream(String blockId, byte[] metadata) { + this.blockId = blockId; + this.metadata = metadata; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK_STREAM; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(blockId); + return objectsHashCode * 41 + Arrays.hashCode(metadata); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlockStream) { + UploadBlockStream o = (UploadBlockStream) other; + return Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + } + + public static UploadBlockStream decode(ByteBuf buf) { + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + return new UploadBlockStream(blockId, metadata); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 6d201b8fe8d7d..d2072a54fa415 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -135,4 +136,23 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { "\"subDirsPerLocalDir\": 7, \"shuffleManager\": " + "\"" + SORT_MANAGER + "\"}"; assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); } + + @Test + public void testNormalizeAndInternPathname() { + assertPathsMatch("/foo", "bar", "baz", "/foo/bar/baz"); + assertPathsMatch("//foo/", "bar/", "//baz", "/foo/bar/baz"); + assertPathsMatch("foo", "bar", "baz///", "foo/bar/baz"); + assertPathsMatch("/foo/", "/bar//", "/baz", "/foo/bar/baz"); + assertPathsMatch("/", "", "", "/"); + assertPathsMatch("/", "/", "/", "/"); + } + + private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) { + String normPathname = + ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3); + assertEquals(expectedPathname, normPathname); + File file = new File(normPathname); + String returnedPath = file.getPath(); + assertTrue(normPathname == returnedPath); + } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java new file mode 100644 index 0000000000000..d22f3ace4103b --- /dev/null +++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class NonShuffleFilesCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + private TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + + @Test + public void cleanupOnRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(true); + } + + @Test + public void cleanupOnRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(false); + } + + private void cleanupOnRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + resolver.executorRemoved("exec0", "app"); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutorWithShuffleFiles() throws IOException { + cleanupUsesExecutor(true); + } + + @Test + public void cleanupUsesExecutorWithoutShuffleFiles() throws IOException { + cleanupUsesExecutor(false); + } + + private void cleanupUsesExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. 
+ Executor noThreadExecutor = runnable -> cleanupCalled.set(true); + + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + manager.executorRemoved("exec0", "app"); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + } + + @Test + public void cleanupOnlyRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(true); + } + + @Test + public void cleanupOnlyRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(false); + } + + private void cleanupOnlyRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext0 = initDataContext(withShuffleFiles); + TestShuffleDataContext dataContext1 = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo(SORT_MANAGER)); + resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo(SORT_MANAGER)); + + + resolver.executorRemoved("exec-nonexistent", "app"); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(true); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(false); + } + + private void cleanupOnlyRegisteredExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + + resolver.executorRemoved("exec1", "app"); + assertStillThere(dataContext); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext); + } + + private static void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private static FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. 
+ return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + private static boolean assertOnlyShuffleDataInDir(File[] dirs) { + for (File dir : dirs) { + assertTrue(dir.getName() + " wasn't cleaned up", !dir.exists() || + dir.listFiles(filter).length == 0 || assertOnlyShuffleDataInDir(dir.listFiles())); + } + return true; + } + + private static void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + File[] dirs = new File[] {new File(localDir)}; + assertOnlyShuffleDataInDir(dirs); + } + } + + private static TestShuffleDataContext initDataContext(boolean withShuffleFiles) + throws IOException { + if (withShuffleFiles) { + return initDataContextWithShuffleFiles(); + } else { + return initDataContextWithoutShuffleFiles(); + } + } + + private static TestShuffleDataContext initDataContextWithShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createShuffleFiles(dataContext); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext initDataContextWithoutShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext createDataContext() { + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + dataContext.create(); + return dataContext; + } + + private static void createShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + Random rand = new Random(123); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { + "ABC".getBytes(StandardCharsets.UTF_8), + "DEF".getBytes(StandardCharsets.UTF_8)}); + } + + private static void createNonShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + // Create spill file(s) + dataContext.insertSpillData(); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 81e01949e50fa..6989c3baf2e28 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -22,6 +22,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.util.UUID; import com.google.common.io.Closeables; import com.google.common.io.Files; @@ -94,6 +95,20 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr } } + /** Creates spill file(s) within the local dirs. */ + public void insertSpillData() throws IOException { + String filename = "temp_local_" + UUID.randomUUID(); + OutputStream dataStream = null; + + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, filename)); + dataStream.write(42); + } finally { + Closeables.close(dataStream, false); + } + } + /** * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this * context's directories. 
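The cleanup logic above ultimately rests on a single predicate: when an executor is removed, any file in its registered local dirs that is not a shuffle ".index" or ".data" file (for example the "temp_local_*" spill files created by the test helper) is a candidate for deletion. A minimal, self-contained Java sketch of that predicate follows; the class and constant names are hypothetical and this is not the resolver's actual entry point.

import java.io.File;
import java.io.FilenameFilter;

// Hypothetical demo class; illustrates the filter idea only, not part of the patch above.
public class NonShuffleFileFilterSketch {

  // Keep shuffle index/data files; anything else (e.g. temp_local_* spill files) may be deleted.
  static final FilenameFilter NON_SHUFFLE_FILES =
      (File dir, String name) -> !name.endsWith(".index") && !name.endsWith(".data");

  public static void main(String[] args) {
    String[] names = { "shuffle_0_0_0.index", "shuffle_0_0_0.data", "temp_local_1234" };
    for (String name : names) {
      System.out.println(name + " -> eligible for deletion: "
          + NON_SHUFFLE_FILES.accept(new File("."), name));
    }
  }
}

In the patch itself the equivalent filter is handed to a recursive delete that runs on the resolver's directory-cleaner executor, so the calling thread is not blocked by potentially slow disk I/O.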
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index c03caf0076f61..ecd7c19f2c634 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,10 +17,12 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.Platform; - import java.util.Arrays; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + public final class ByteArray { public static final byte[] EMPTY_BYTE = new byte[0]; @@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { public static byte[] concat(byte[]... inputs) { // Compute the total length of the result - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].length; + totalLength += (long)inputs[i].length; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].length; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e9b3d9b045af5..e91fc4391425c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -29,8 +29,8 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; - import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -877,17 +877,17 @@ public UTF8String lpad(int len, UTF8String pad) { */ public static UTF8String concat(UTF8String... inputs) { // Compute the total length of the result. - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].numBytes; + totalLength += (long)inputs[i].numBytes; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it. 
- final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index d7ed005db1891..d9898771720ae 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -77,15 +77,15 @@ public void testKnownWordsInputs() { for (int i = 0; i < 16; i++) { bytes[i] = 0; } - Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + Assert.assertEquals(-300363099, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); for (int i = 0; i < 16; i++) { bytes[i] = -1; } - Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + Assert.assertEquals(-1210324667, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); for (int i = 0; i < 16; i++) { bytes[i] = (byte)i; } - Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + Assert.assertEquals(-634919701, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); } @Test diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 5d5fdc1c55a75..ef5ff8ee70ec0 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -120,6 +120,8 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } catch (Exception expected) { Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); } + + memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); } @Test @@ -165,11 +167,13 @@ public void testOffHeapArrayMemoryBlock() { int length = 56; check(memory, obj, offset, length); + memoryAllocator.free(memory); long address = Platform.allocateMemory(112); memory = new OffHeapMemoryBlock(address, length); obj = memory.getBaseObject(); offset = memory.getBaseOffset(); check(memory, obj, offset, length); + Platform.freeMemory(address); } } diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 48004e812a8bf..7d3331f44f015 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -192,8 +192,8 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString)) test("concat") { - def concat(orgin: Seq[String]): String = - if (orgin.contains(null)) null else orgin.mkString + def concat(origin: Seq[String]): String = + if (origin.contains(null)) null else origin.mkString forAll { (inputs: Seq[String]) => assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString)) diff --git a/core/pom.xml b/core/pom.xml index 9258a856028a0..5fa3a86de6b01 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -56,7 +56,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.apache.hadoop @@ -88,13 
+88,19 @@ ${project.version} - net.java.dev.jets3t - jets3t + javax.activation + activation org.apache.curator curator-recipes + + + org.apache.zookeeper + zookeeper + @@ -344,7 +350,7 @@ net.sf.py4j py4j - 0.10.6 + 0.10.7 org.apache.spark diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c52831..e3bd5496cf5ba 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,7 +167,8 @@ public void write(Iterator> records) throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c3a07b2abf896..c7d2db4217d96 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -43,6 +43,7 @@ import org.apache.spark.storage.FileSegment; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; @@ -184,6 +185,7 @@ private void writeSortedFile(boolean isLastFile) { blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); int currentPartition = -1; + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); while (sortedRecords.hasNext()) { sortedRecords.loadNext(); final int partition = sortedRecords.packedRecordPointer.getPartitionId(); @@ -200,8 +202,8 @@ private void writeSortedFile(boolean isLastFile) { final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); final Object recordPage = taskMemoryManager.getPage(recordPointer); final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); - long recordReadPosition = recordOffsetInPage + 4; // skip over record length + int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); Platform.copyMemory( @@ -389,15 +391,16 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p } growPointerArrayIfNecessary(); - // Need 4 bytes to store the record length. 
- final int required = length + 4; + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + // Need 4 or 8 bytes to store the record length. + final int required = length + uaoSize; acquireNewPageIfNecessary(required); assert(currentPage != null); final Object base = currentPage.getBaseObject(); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - Platform.putInt(base, pageCursor, length); - pageCursor += 4; + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, partitionId); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 8f49859746b89..4b48599ad311e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -65,7 +65,7 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int usableCapacity = 0; - private int initialSize; + private final int initialSize; ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) { this.consumer = consumer; @@ -94,12 +94,20 @@ public int numRecords() { } public void reset() { + // Reset `pos` here so that `spill` triggered by the below `allocateArray` will be no-op. + pos = 0; if (consumer != null) { consumer.freeArray(array); + // As `array` has been released, we should set it to `null` to avoid accessing it before + // `allocateArray` returns. `usableCapacity` is also set to `0` to avoid any codes writing + // data to `ShuffleInMemorySorter` when `array` is `null` (e.g., in + // ShuffleExternalSorter.growPointerArrayIfNecessary, we may try to access + // `ShuffleInMemorySorter` when `allocateArray` throws SparkOutOfMemoryError). + array = null; + usableCapacity = 0; array = consumer.allocateArray(initialSize); usableCapacity = getUsableCapacity(); } - pos = 0; } public void expandPointerArray(LongArray newArray) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..069e6d5f224d7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -248,7 +248,8 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f0045507aaab..9b6cbab38cbcc 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -662,7 +662,7 @@ public int getValueLength() { * It is only valid to call this method immediately after calling `lookup()` using the same key. *
</p> * <p> - * The key and value must be word-aligned (that is, their sizes must multiples of 8). + * The key and value must be word-aligned (that is, their sizes must be a multiple of 8). * </p> * <p>
* After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` @@ -703,7 +703,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final long recordLength = (2 * uaoSize) + klen + vlen + 8; + final long recordLength = (2L * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { if (!acquireNewPage(recordLength + uaoSize)) { return false; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 4fc19b1721518..399251b80e649 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -402,7 +402,7 @@ public void insertRecord( growPointerArrayIfNecessary(); int uaoSize = UnsafeAlignedOffset.getUaoSize(); - // Need 4 bytes to store the record length. + // Need 4 or 8 bytes to store the record length. final int required = length + uaoSize; acquireNewPageIfNecessary(required); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index ff0dcc259a4ad..ab800288dcb43 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -51,7 +51,7 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOExcept if (spillReader.hasNext()) { // We only add the spillReader to the priorityQueue if it is not empty. We do this to // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator - // does not return wrong result because hasNext will returns true + // does not return wrong result because hasNext will return true // at least priorityQueue.size() times. If we allow n spillReaders in the // priorityQueue, we will have n extra empty records in the result of UnsafeSorterIterator. spillReader.loadNext(); diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala new file mode 100644 index 0000000000000..5e546c694e8d9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.util.{Timer, TimerTask} +import java.util.concurrent.ConcurrentHashMap +import java.util.function.{Consumer, Function} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} + +/** + * For each barrier stage attempt, only at most one barrier() call can be active at any time, thus + * we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is + * from. + */ +private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { + override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)" +} + +/** + * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync + * request is generated by `BarrierTaskContext.barrier()`, and identified by + * stageId + stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon + * all the requests for a group of `barrier()` calls are received. If the coordinator is unable to + * collect enough global sync requests within a configured time, fail all the requests and return + * an Exception with timeout message. + */ +private[spark] class BarrierCoordinator( + timeoutInSecs: Long, + listenerBus: LiveListenerBus, + override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + + // TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to + // fetch result, we shall fix the issue. + private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + + // Listen to StageCompleted event, clear corresponding ContextBarrierState. + private val listener = new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + val stageInfo = stageCompleted.stageInfo + val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber) + // Clear ContextBarrierState from a finished stage attempt. + cleanupBarrierStage(barrierId) + } + } + + // Record all active stage attempts that make barrier() call(s), and the corresponding internal + // state. + private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] + + override def onStart(): Unit = { + super.onStart() + listenerBus.addToStatusQueue(listener) + } + + override def onStop(): Unit = { + try { + states.forEachValue(1, clearStateConsumer) + states.clear() + listenerBus.removeListener(listener) + } finally { + super.onStop() + } + } + + /** + * Provide the current state of a barrier() call. A state is created when a new stage attempt + * sends out a barrier() call, and recycled on stage completed. + * + * @param barrierId Identifier of the barrier stage that make a barrier() call. + * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall + * collect `numTasks` requests to succeed. + */ + private class ContextBarrierState( + val barrierId: ContextBarrierId, + val numTasks: Int) { + + // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used + // to identify each barrier() call. It shall get increased when a barrier() call succeeds, or + // reset when a barrier() call fails due to timeout. + private var barrierEpoch: Int = 0 + + // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() + // call. 
+ private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + + // A timer task that ensures we may timeout for a barrier() call. + private var timerTask: TimerTask = null + + // Init a TimerTask for a barrier() call. + private def initTimerTask(): Unit = { + timerTask = new TimerTask { + override def run(): Unit = synchronized { + // Timeout current barrier() call, fail all the sync requests. + requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " + + s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " + + s"$timeoutInSecs second(s)."))) + cleanupBarrierStage(barrierId) + } + } + } + + // Cancel the current active TimerTask and release resources. + private def cancelTimerTask(): Unit = { + if (timerTask != null) { + timerTask.cancel() + timerTask = null + } + } + + // Process the global sync request. The barrier() call succeed if collected enough requests + // within a configured time, otherwise fail all the pending requests. + def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + val taskId = request.taskAttemptId + val epoch = request.barrierEpoch + + // Require the number of tasks is correctly set from the BarrierTaskContext. + require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + + s"${request.numTasks} from Task $taskId, previously it was $numTasks.") + + // Check whether the epoch from the barrier tasks matches current barrierEpoch. + logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") + if (epoch != barrierEpoch) { + requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " + + s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " + + "properly killed.")) + } else { + // If this is the first sync message received for a barrier() call, start timer to ensure + // we may timeout for the sync. + if (requesters.isEmpty) { + initTimerTask() + timer.schedule(timerTask, timeoutInSecs * 1000) + } + // Add the requester to array of RPCCallContexts pending for reply. + requesters += requester + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + + s"$taskId, current progress: ${requesters.size}/$numTasks.") + if (maybeFinishAllRequesters(requesters, numTasks)) { + // Finished current barrier() call successfully, clean up ContextBarrierState and + // increase the barrier epoch. + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + + s"tasks, finished successfully.") + barrierEpoch += 1 + requesters.clear() + cancelTimerTask() + } + } + } + + // Finish all the blocking barrier sync requests from a stage attempt successfully if we + // have received all the sync requests. + private def maybeFinishAllRequesters( + requesters: ArrayBuffer[RpcCallContext], + numTasks: Int): Boolean = { + if (requesters.size == numTasks) { + requesters.foreach(_.reply(())) + true + } else { + false + } + } + + // Cleanup the internal state of a barrier stage attempt. + def clear(): Unit = synchronized { + // The global sync fails so the stage is expected to retry another attempt, all sync + // messages come from current stage attempt shall fail. + barrierEpoch = -1 + requesters.clear() + cancelTimerTask() + } + } + + // Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt. 
+ private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = { + val barrierState = states.remove(barrierId) + if (barrierState != null) { + barrierState.clear() + } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => + // Get or init the ContextBarrierState correspond to the stage attempt. + val barrierId = ContextBarrierId(stageId, stageAttemptId) + states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] { + override def apply(key: ContextBarrierId): ContextBarrierState = + new ContextBarrierState(key, numTasks) + }) + val barrierState = states.get(barrierId) + + barrierState.handleRequest(context, request) + } + + private val clearStateConsumer = new Consumer[ContextBarrierState] { + override def accept(state: ContextBarrierState) = state.clear() + } +} + +private[spark] sealed trait BarrierCoordinatorMessage extends Serializable + +/** + * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is + * identified by stageId + stageAttemptId + barrierEpoch. + * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. + */ +private[spark] case class RequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int) extends BarrierCoordinatorMessage diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala new file mode 100644 index 0000000000000..de827987f28f9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.{Properties, Timer, TimerTask} + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} +import org.apache.spark.util.{RpcUtils, Utils} + +/** A [[TaskContext]] with extra info and tooling for a barrier stage. 
*/ +class BarrierTaskContext( + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, + override val taskAttemptId: Long, + override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, + @transient private val metricsSystem: MetricsSystem, + // The default value is only used in tests. + override val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, + taskMemoryManager, localProperties, metricsSystem, taskMetrics) { + + // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls. + private val barrierCoordinator: RpcEndpointRef = { + val env = SparkEnv.get + RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) + } + + private val timer = new Timer("Barrier task timer for barrier() calls.") + + // Local barrierEpoch that identify a barrier() call from current task, it shall be identical + // with the driver side epoch. + private var barrierEpoch = 0 + + // Number of tasks of the current barrier stage, a barrier() call must collect enough requests + // from different tasks within the same barrier stage attempt to succeed. + private lazy val numTasks = getTaskInfos().size + + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of misuses listed below: + * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + + s"the global sync, current barrier epoch is $barrierEpoch.") + logTrace("Current callSite: " + Utils.getCallSite()) + + val startTime = System.currentTimeMillis() + val timerTask = new TimerTask { + override def run(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " + + s"under the global sync since $startTime, has been waiting for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + } + } + // Log the update of global sync every 60 seconds. 
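[Editor's aside] For contrast with the misuse examples above, a minimal sketch of a correct barrier stage, in which every task reaches barrier() exactly once and then inspects its peers via getTaskInfos(); the RDD contents and partition count are illustrative, and an existing SparkContext `sc` is assumed:

{{{
// Every task of the 4-partition barrier stage calls barrier() exactly once.
val rdd = sc.parallelize(1 to 100, 4)
val result = rdd.barrier().mapPartitions { iter =>
  val context = BarrierTaskContext.get()
  // Phase 1: per-partition work.
  val localData = iter.toArray
  // Global sync: blocks until all 4 tasks of this stage attempt have called barrier().
  context.barrier()
  // After the sync, every peer's executor address is visible, ordered by partitionId.
  val peers = context.getTaskInfos().map(_.address)
  // Phase 2: runs only after every task has finished phase 1.
  localData.iterator
}.collect()
}}}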
+ timer.schedule(timerTask, 60000, 60000) + + try { + barrierCoordinator.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch), + // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by + // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. + timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout")) + barrierEpoch += 1 + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + + "global sync successfully, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " + + s"$barrierEpoch.") + } catch { + case e: SparkException => + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " + + "to perform global sync, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + throw e + } finally { + timerTask.cancel() + } + } + + /** + * :: Experimental :: + * Returns the all task infos in this barrier stage, the task infos are ordered by partitionId. + */ + @Experimental + @Since("2.4.0") + def getTaskInfos(): Array[BarrierTaskInfo] = { + val addressesStr = localProperties.getProperty("addresses", "") + addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) + } +} + +object BarrierTaskContext { + /** + * Return the currently active BarrierTaskContext. This can be called inside of user functions to + * access contextual information about running barrier tasks. + */ + def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext] +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala new file mode 100644 index 0000000000000..ce2653df2e845 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.{Experimental, Since} + + +/** + * :: Experimental :: + * Carries all task infos of a barrier task. 
+ * + * @param address the IPv4 address(host:port) of the executor that a barrier task is running on + */ +@Experimental +@Since("2.4.0") +class BarrierTaskInfo(val address: String) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 189d91333c045..17b88631bcb4c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} +import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMaster @@ -69,6 +69,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors * spark.dynamicAllocation.initialExecutors - Number of executors to start with * + * spark.dynamicAllocation.executorAllocationRatio - + * This is used to reduce the parallelism of the dynamic allocation that can waste + * resources when tasks are small + * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors * @@ -116,9 +120,12 @@ private[spark] class ExecutorAllocationManager( // TODO: The default value of 1 for spark.executor.cores works right now because dynamic // allocation is only supported for YARN and the default number of cores per executor in YARN is // 1, but it might need to be attained differently for different cluster managers - private val tasksPerExecutor = + private val tasksPerExecutorForFullParallelism = conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + private val executorAllocationRatio = + conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) + validateSettings() // Number of executors to add in the next round @@ -209,8 +216,13 @@ private[spark] class ExecutorAllocationManager( throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. 
You may enable this through spark.shuffle.service.enabled.") } - if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + if (tasksPerExecutorForFullParallelism == 0) { + throw new SparkException("spark.executor.cores must not be < spark.task.cpus.") + } + + if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { + throw new SparkException( + "spark.dynamicAllocation.executorAllocationRatio must be > 0 and <= 1.0") } } @@ -273,7 +285,9 @@ private[spark] class ExecutorAllocationManager( */ private def maxNumExecutorsNeeded(): Int = { val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks - (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + math.ceil(numRunningOrPendingTasks * executorAllocationRatio / + tasksPerExecutorForFullParallelism) + .toInt } private def totalRunningTasks(): Int = synchronized { @@ -474,9 +488,15 @@ private[spark] class ExecutorAllocationManager( newExecutorTotal = numExistingExecutors if (testing || executorsRemoved.nonEmpty) { executorsRemoved.foreach { removedExecutorId => + // If it is a cached block, it uses cachedExecutorIdleTimeoutS for timeout + val idleTimeout = if (blockManagerMaster.hasCachedBlocks(removedExecutorId)) { + cachedExecutorIdleTimeoutS + } else { + executorIdleTimeoutS + } newExecutorTotal -= 1 logInfo(s"Removing executor $removedExecutorId because it has been idle for " + - s"$executorIdleTimeoutS seconds (new desired total will be $newExecutorTotal)") + s"$idleTimeout seconds (new desired total will be $newExecutorTotal)") executorsPendingToRemove.add(removedExecutorId) } executorsRemoved diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index ff960b396dbf1..bcbc8df0d5865 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -74,10 +74,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = - sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") private val executorTimeoutMs = - sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s") // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index f8a6f1d0d8cbb..ff85e11409e35 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,5 +23,9 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) + * @param recordsByPartitionId number of output records for each map output partition */ -private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics( + val shuffleId: Int, + val bytesByPartitionId: Array[Long], + val 
recordsByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 195fd4f818b36..41575ce4e6e3d 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -296,7 +296,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -434,6 +434,18 @@ private[spark] class MapOutputTrackerMaster( } } + /** Unregister all map output information of the given shuffle. */ + def unregisterAllMapOutput(shuffleId: Int) { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeOutputsByFilter(x => true) + incrementEpoch() + case None => + throw new SparkException( + s"unregisterAllMapOutput called for nonexistent shuffle ID $shuffleId.") + } + } + /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int) { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => @@ -510,16 +522,19 @@ private[spark] class MapOutputTrackerMaster( def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) + val recordsByMapTask = new Array[Long](statuses.length) + val parallelAggThreshold = conf.get( SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) val parallelism = math.min( Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt if (parallelism <= 1) { - for (s <- statuses) { + statuses.zipWithIndex.foreach { case (s, index) => for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } + recordsByMapTask(index) = s.numberOfOutput } } else { val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") @@ -536,8 +551,11 @@ private[spark] class MapOutputTrackerMaster( } finally { threadPool.shutdown() } + statuses.zipWithIndex.foreach { case (s, index) => + recordsByMapTask(index) = s.numberOfOutput + } } - new MapOutputStatistics(dep.shuffleId, totalSizes) + new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask) } } @@ -632,9 +650,10 @@ private[spark] class MapOutputTrackerMaster( } } + // Get blocks sizes by executor Id. 
Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -642,7 +661,7 @@ private[spark] class MapOutputTrackerMaster( MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } case None => - Seq.empty + Iterator.empty } } @@ -669,8 +688,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -841,6 +861,7 @@ private[spark] object MapOutputTracker extends Logging { * Given an array of map statuses and a range of map output partitions, returns a sequence that, * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes * stored at that block manager. + * Note that empty blocks are filtered in the result. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. 
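[Editor's aside] A self-contained sketch of the group-by-address idiom the new code uses, with zero-sized blocks dropped and the result exposed as an iterator; the String/Long types stand in for BlockManagerId, BlockId and block sizes:

{{{
import scala.collection.mutable.{HashMap, ListBuffer}

// Group non-empty blocks by the address that hosts them; skip zero-sized blocks.
def groupNonEmptyBlocks(
    blocks: Seq[(String, String, Long)]): Iterator[(String, Seq[(String, Long)])] = {
  val splitsByAddress = new HashMap[String, ListBuffer[(String, Long)]]
  for ((address, blockId, size) <- blocks) {
    if (size != 0) {
      splitsByAddress.getOrElseUpdate(address, ListBuffer()) += ((blockId, size))
    }
  }
  // Returning an iterator avoids materializing another Seq, mirroring convertMapStatuses.
  splitsByAddress.iterator
}
}}}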
@@ -857,22 +878,24 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.zipWithIndex) { + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { for (part <- startPartition until endPartition) { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } - - splitsByAddress.toSeq + splitsByAddress.iterator } } diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 477b01968c6ef..1632e0c69eef5 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -21,6 +21,7 @@ import java.io.File import java.security.NoSuchAlgorithmException import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration import org.eclipse.jetty.util.ssl.SslContextFactory import org.apache.spark.internal.Logging @@ -128,7 +129,7 @@ private[spark] case class SSLOptions( } /** Returns a string representation of this SSLOptions with all the passwords masked. */ - override def toString: String = s"SSLOptions{enabled=$enabled, " + + override def toString: String = s"SSLOptions{enabled=$enabled, port=$port, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " + s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}" @@ -142,6 +143,7 @@ private[spark] object SSLOptions extends Logging { * * The following settings are allowed: * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].port` - the port where to bind the SSL server * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory * $ - `[ns].keyStorePassword` - a password to the key-store file * $ - `[ns].keyPassword` - a password to the private key @@ -162,11 +164,16 @@ private[spark] object SSLOptions extends Logging { * missing in SparkConf, the corresponding setting is used from the default configuration. 
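[Editor's aside] A hedged sketch of the lookup order the additions below implement for each password: a plain SparkConf value wins, then a Hadoop credential provider entry, then the namespace defaults; the jceks path here is made up:

{{{
import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkConf

val sparkConf = new SparkConf()
val hadoopConf = new Configuration()
// Point Hadoop at a credential store holding spark.ssl.keyStorePassword (illustrative path).
hadoopConf.set("hadoop.security.credential.provider.path", "jceks://file/tmp/ssl.jceks")

val keyStorePassword: Option[String] =
  sparkConf.getOption("spark.ssl.keyStorePassword")
    .orElse(Option(hadoopConf.getPassword("spark.ssl.keyStorePassword")).map(new String(_)))
}}}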
* * @param conf Spark configuration object where the settings are collected from + * @param hadoopConf Hadoop configuration to get settings * @param ns the namespace name * @param defaults the default configuration * @return [[org.apache.spark.SSLOptions]] object */ - def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + def parse( + conf: SparkConf, + hadoopConf: Configuration, + ns: String, + defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) @@ -178,9 +185,11 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.keyStore)) val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyStorePassword)) val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyPassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyPassword)) val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") @@ -193,6 +202,7 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.trustStore)) val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.trustStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.trustStorePassword)) val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 09ec8932353a0..3cfafeb951105 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,18 +17,13 @@ package org.apache.spark -import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import java.security.{KeyStore, SecureRandom} -import java.security.cert.X509Certificate -import javax.net.ssl._ -import com.google.common.hash.HashCodes -import com.google.common.io.Files import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher @@ -89,6 +84,7 @@ private[spark] class SecurityManager( setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + private var secretKey: String = _ logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -115,11 +111,14 @@ private[spark] class SecurityManager( ) } + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) // the default SSL configuration - it will be used by all communication layers unless overwritten - private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) + private val defaultSSLOptions = + SSLOptions.parse(sparkConf, hadoopConf, "spark.ssl", defaults = None) def getSSLOptions(module: String): SSLOptions = { - val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", 
Some(defaultSSLOptions)) + val opts = + SSLOptions.parse(sparkConf, hadoopConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") opts } @@ -321,6 +320,12 @@ private[spark] class SecurityManager( val creds = UserGroupInformation.getCurrentUser().getCredentials() Option(creds.getSecretKey(SECRET_LOOKUP_KEY)) .map { bytes => new String(bytes, UTF_8) } + // Secret key may not be found in current UGI's credentials. + // This happens when UGI is refreshed in the driver side by UGI's loginFromKeytab but not + // copy secret key from original UGI to the new one. This exists in ThriftServer's Hive + // logic. So as a workaround, storing secret key in a local variable to make it visible + // in different context. + .orElse(Option(secretKey)) .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) .getOrElse { @@ -358,14 +363,9 @@ private[spark] class SecurityManager( return } - val rnd = new SecureRandom() - val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE - val secretBytes = new Array[Byte](length) - rnd.nextBytes(secretBytes) - + secretKey = Utils.createSecret(sparkConf) val creds = new Credentials() - val secretStr = HashCodes.fromBytes(secretBytes).toString() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) + creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 129956e9f9ffa..6c4c5c94cfa28 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -265,16 +265,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String): Long = { + def getTimeAsSeconds(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key)) } /** * Get a time parameter as seconds, falling back to a default if not set. If no * suffix is provided then seconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String, defaultValue: String): Long = { + def getTimeAsSeconds(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key, defaultValue)) } @@ -282,16 +284,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then milliseconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String): Long = { + def getTimeAsMs(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key)) } /** * Get a time parameter as milliseconds, falling back to a default if not set. If no * suffix is provided then milliseconds are assumed. 
+ * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String, defaultValue: String): Long = { + def getTimeAsMs(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key, defaultValue)) } @@ -299,23 +303,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then bytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String): Long = { + def getSizeAsBytes(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key)) } /** * Get a size parameter as bytes, falling back to a default if not set. If no * suffix is provided then bytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: String): Long = { + def getSizeAsBytes(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue)) } /** * Get a size parameter as bytes, falling back to a default if not set. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: Long): Long = { + def getSizeAsBytes(key: String, defaultValue: Long): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue + "B")) } @@ -323,16 +330,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String): Long = { + def getSizeAsKb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key)) } /** * Get a size parameter as Kibibytes, falling back to a default if not set. If no * suffix is provided then Kibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String, defaultValue: String): Long = { + def getSizeAsKb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key, defaultValue)) } @@ -340,16 +349,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String): Long = { + def getSizeAsMb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key)) } /** * Get a size parameter as Mebibytes, falling back to a default if not set. If no * suffix is provided then Mebibytes are assumed. 
+ * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String, defaultValue: String): Long = { + def getSizeAsMb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key, defaultValue)) } @@ -357,16 +368,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String): Long = { + def getSizeAsGb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key)) } /** * Get a size parameter as Gibibytes, falling back to a default if not set. If no * suffix is provided then Gibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String, defaultValue: String): Long = { + def getSizeAsGb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key, defaultValue)) } @@ -394,23 +407,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } - /** Get a parameter as an integer, falling back to a default if not set */ - def getInt(key: String, defaultValue: Int): Int = { + /** + * Get a parameter as an integer, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as an integer + */ + def getInt(key: String, defaultValue: Int): Int = catchIllegalValue(key) { getOption(key).map(_.toInt).getOrElse(defaultValue) } - /** Get a parameter as a long, falling back to a default if not set */ - def getLong(key: String, defaultValue: Long): Long = { + /** + * Get a parameter as a long, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as a long + */ + def getLong(key: String, defaultValue: Long): Long = catchIllegalValue(key) { getOption(key).map(_.toLong).getOrElse(defaultValue) } - /** Get a parameter as a double, falling back to a default if not set */ - def getDouble(key: String, defaultValue: Double): Double = { + /** + * Get a parameter as a double, falling back to a default if not ste + * @throws NumberFormatException If the value cannot be interpreted as a double + */ + def getDouble(key: String, defaultValue: Double): Double = catchIllegalValue(key) { getOption(key).map(_.toDouble).getOrElse(defaultValue) } - /** Get a parameter as a boolean, falling back to a default if not set */ - def getBoolean(key: String, defaultValue: Boolean): Boolean = { + /** + * Get a parameter as a boolean, falling back to a default if not set + * @throws IllegalArgumentException If the value cannot be interpreted as a boolean + */ + def getBoolean(key: String, defaultValue: Boolean): Boolean = catchIllegalValue(key) { getOption(key).map(_.toBoolean).getOrElse(defaultValue) } @@ -448,14 +473,33 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria */ private[spark] def getenv(name: String): String = System.getenv(name) + /** + * Wrapper method for get() methods which require some specific value format. This catches + * any [[NumberFormatException]] or [[IllegalArgumentException]] and re-raises it with the + * incorrectly configured key in the exception message. 
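[Editor's aside] A small usage sketch of the effect this wrapper has on the typed getters; the key and the bad value are made up:

{{{
import org.apache.spark.SparkConf

val conf = new SparkConf()
conf.set("spark.network.timeout", "ten minutes")  // not a parseable time string

try {
  conf.getTimeAsSeconds("spark.network.timeout", "120s")
} catch {
  // The rethrown exception now names the offending key, e.g.
  // "Illegal value for config key spark.network.timeout: ..."
  case e: NumberFormatException => println(e.getMessage)
}
}}}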
+ */ + private def catchIllegalValue[T](key: String)(getValue: => T): T = { + try { + getValue + } catch { + case e: NumberFormatException => + // NumberFormatException doesn't have a constructor that takes a cause for some reason. + throw new NumberFormatException(s"Illegal value for config key $key: ${e.getMessage}") + .initCause(e) + case e: IllegalArgumentException => + throw new IllegalArgumentException(s"Illegal value for config key $key: ${e.getMessage}", e) + } + } + /** * Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { if (contains("spark.local.dir")) { - val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + - "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." + val msg = "Note that spark.local.dir will be overridden by the value set by " + + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS" + + " in YARN)." logWarning(msg) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5e8595603cc90..e5b1e0ecd1586 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -254,7 +254,7 @@ class SparkContext(config: SparkConf) extends Logging { conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { - SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master)) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master, conf)) } private[spark] def env: SparkEnv = _env @@ -571,7 +571,12 @@ class SparkContext(config: SparkConf) extends Logging { _shutdownHookRef = ShutdownHookManager.addShutdownHook( ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") - stop() + try { + stop() + } catch { + case e: Throwable => + logWarning("Ignoring Exception while stopping SparkContext from shutdown hook", e) + } } } catch { case NonFatal(e) => @@ -1306,11 +1311,12 @@ class SparkContext(config: SparkConf) extends Logging { /** Build the union of a list of RDDs. */ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope { - val partitioners = rdds.flatMap(_.partitioner).toSet - if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { - new PartitionerAwareUnionRDD(this, rdds) + val nonEmptyRdds = rdds.filter(!_.partitions.isEmpty) + val partitioners = nonEmptyRdds.flatMap(_.partitioner).toSet + if (nonEmptyRdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, nonEmptyRdds) } else { - new UnionRDD(this, rdds) + new UnionRDD(this, nonEmptyRdds) } } @@ -1495,6 +1501,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param path can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String): Unit = { addFile(path, false) @@ -1515,11 +1523,17 @@ class SparkContext(config: SparkConf) extends Logging { * use `SparkFiles.get(fileName)` to find its download location. 
* @param recursive if true, a directory can be given in `path`. Currently directories are * only supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri val schemeCorrectedPath = uri.getScheme match { - case null | "local" => new File(path).getCanonicalFile.toURI.toString + case null => new File(path).getCanonicalFile.toURI.toString + case "local" => + logWarning("File with 'local' scheme is not supported to add to file server, since " + + "it is already available on every node.") + return case _ => path } @@ -1554,6 +1568,9 @@ class SparkContext(config: SparkConf) extends Logging { Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConfiguration, timestamp, useCache = false) postEnvironmentUpdate() + } else { + logWarning(s"The path $path has been added already. Overwriting of added paths " + + "is not supported in the current version.") } } @@ -1585,6 +1602,15 @@ class SparkContext(config: SparkConf) extends Logging { } } + /** + * Get the max number of tasks that can be concurrent launched currently. + * Note that please don't cache the value returned by this method, because the number can change + * due to add/remove executors. + * + * @return The max number of tasks that can be concurrent launched currently. + */ + private[spark] def maxNumConcurrentTasks(): Int = schedulerBackend.maxNumConcurrentTasks() + /** * Update the cluster manager on our scheduling needs. Three bits of information are included * to help it make decisions. @@ -1802,6 +1828,8 @@ class SparkContext(config: SparkConf) extends Logging { * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { def addJarFile(file: File): String = { @@ -1848,6 +1876,9 @@ class SparkContext(config: SparkConf) extends Logging { if (addedJars.putIfAbsent(key, timestamp).isEmpty) { logInfo(s"Added JAR $path at $key with timestamp $timestamp") postEnvironmentUpdate() + } else { + logWarning(s"The jar $path has been added already. Overwriting of added jars " + + "is not supported in the current version.") } } } @@ -1913,6 +1944,12 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _executorAllocationManager.foreach(_.stop()) } + if (_dagScheduler != null) { + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } + _dagScheduler = null + } if (_listenerBusStarted) { Utils.tryLogNonFatalError { listenerBus.stop() @@ -1922,12 +1959,6 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } - if (_dagScheduler != null) { - Utils.tryLogNonFatalError { - _dagScheduler.stop() - } - _dagScheduler = null - } if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) @@ -2651,9 +2682,16 @@ object SparkContext extends Logging { } /** - * The number of driver cores to use for execution in local mode, 0 otherwise. 
+ * The number of cores available to the driver to use for tasks such as I/O with Netty */ private[spark] def numDriverCores(master: String): Int = { + numDriverCores(master, null) + } + + /** + * The number of cores available to the driver to use for tasks such as I/O with Netty + */ + private[spark] def numDriverCores(master: String, conf: SparkConf): Int = { def convertToInt(threads: String): Int = { if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt } @@ -2661,7 +2699,13 @@ object SparkContext extends Logging { case "local" => 1 case SparkMasterRegex.LOCAL_N_REGEX(threads) => convertToInt(threads) case SparkMasterRegex.LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) - case _ => 0 // driver is not used for execution + case "yarn" => + if (conf != null && conf.getOption("spark.submit.deployMode").contains("cluster")) { + conf.getInt("spark.driver.cores", 0) + } else { + 0 + } + case _ => 0 // Either driver is not being used, or its core count will be interpolated later } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 69739745aa6cf..ceadf108c86cd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -123,7 +123,10 @@ abstract class TaskContext extends Serializable { * * Exceptions thrown by the listener will result in failure of the task. */ - def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = { + def addTaskCompletionListener[U](f: (TaskContext) => U): TaskContext = { + // Note that due to this scala bug: https://github.com/scala/bug/issues/11016, we need to make + // this function polymorphic for every scala version >= 2.12, otherwise an overloaded method + // resolution error occurs at compile time. addTaskCompletionListener(new TaskCompletionListener { override def onTaskCompletion(context: TaskContext): Unit = f(context) }) diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index a76283e33fa65..33901bc8380e9 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. 
*/ @DeveloperApi -case class TaskKilled(reason: String) extends TaskFailedReason { +case class TaskKilled( + reason: String, + accumUpdates: Seq[AccumulableInfo] = Seq.empty, + private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil) + extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false + } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index b5c4c705dcbc7..c2ebd388a2365 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate import java.util.{Arrays, Properties} -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.util.concurrent.{TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -172,22 +172,24 @@ private[spark] object TestUtils { /** * Run some code involving jobs submitted to the given context and assert that the jobs spilled. */ - def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { - val spillListener = new SpillListener - sc.addSparkListener(spillListener) - body - assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { + val listener = new SpillListener + withListener(sc, listener) { _ => + body + } + assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not") } /** * Run some code involving jobs submitted to the given context and assert that the jobs * did not spill. */ - def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { - val spillListener = new SpillListener - sc.addSparkListener(spillListener) - body - assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { + val listener = new SpillListener + withListener(sc, listener) { _ => + body + } + assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did") } /** @@ -233,6 +235,21 @@ private[spark] object TestUtils { } } + /** + * Runs some code with the given listener installed in the SparkContext. After the code runs, + * this method will wait until all events posted to the listener bus are processed, and then + * remove the listener from the bus. + */ + def withListener[L <: SparkListener](sc: SparkContext, listener: L) (body: L => Unit): Unit = { + sc.addSparkListener(listener) + try { + body(listener) + } finally { + sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10)) + sc.listenerBus.removeListener(listener) + } + } + /** * Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting * time elapsed before `numExecutors` executors up. Exposed for testing. 
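[Editor's aside] A hedged sketch of how a test body in the org.apache.spark package might use the new withListener helper; the listener and job are illustrative, and `sc` is assumed to be an existing SparkContext:

{{{
import org.apache.spark.TestUtils
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}

// A tiny listener that counts finished jobs.
class JobCounter extends SparkListener {
  @volatile var jobs = 0
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = jobs += 1
}

val counter = new JobCounter
TestUtils.withListener(sc, counter) { _ =>
  sc.parallelize(1 to 10, 2).count()  // runs exactly one job
}
// withListener has drained the listener bus and removed the listener,
// so the count is stable by the time it is asserted on.
assert(counter.jobs == 1)
}}}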
@@ -289,21 +306,17 @@ private[spark] object TestUtils { private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] private val spilledStageIds = new mutable.HashSet[Int] - private val stagesDone = new CountDownLatch(1) - def numSpilledStages: Int = { - // Long timeout, just in case somehow the job end isn't notified. - // Fails if a timeout occurs - assert(stagesDone.await(10, TimeUnit.SECONDS)) + def numSpilledStages: Int = synchronized { spilledStageIds.size } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { stageIdToTaskMetrics.getOrElseUpdate( taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics } - override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = { + override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = synchronized { val stageId = stageComplete.stageInfo.stageId val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten val spilled = metrics.map(_.memoryBytesSpilled).sum > 0 @@ -311,8 +324,4 @@ private class SpillListener extends SparkListener { spilledStageIds += stageId } } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - stagesDone.countDown() - } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index f1936bf587282..09c83849e26b2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -668,6 +668,8 @@ class JavaSparkContext(val sc: SparkContext) * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String) { sc.addFile(path) @@ -681,6 +683,8 @@ class JavaSparkContext(val sc: SparkContext) * * A directory can be given if the recursive option is set to true. Currently directories are only * supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { sc.addFile(path, recursive) @@ -690,6 +694,8 @@ class JavaSparkContext(val sc: SparkContext) * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. 
*/ def addJar(path: String) { sc.addJar(path) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala index 11f2432575d84..9ddc4a4910180 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -17,26 +17,39 @@ package org.apache.spark.api.python -import java.io.DataOutputStream -import java.net.Socket +import java.io.{DataOutputStream, File, FileOutputStream} +import java.net.InetAddress +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files import py4j.GatewayServer +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port - * back to its caller via a callback port specified by the caller. + * Process that starts a Py4J GatewayServer on an ephemeral port. * * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). */ private[spark] object PythonGatewayServer extends Logging { initializeLogIfNecessary(true) - def main(args: Array[String]): Unit = Utils.tryOrExit { - // Start a GatewayServer on an ephemeral port - val gatewayServer: GatewayServer = new GatewayServer(null, 0) + def main(args: Array[String]): Unit = { + val secret = Utils.createSecret(new SparkConf()) + + // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured + // with the same secret, in case the app needs callbacks from the JVM to the underlying + // python processes. + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() + gatewayServer.start() val boundPort: Int = gatewayServer.getListeningPort if (boundPort == -1) { @@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging { logDebug(s"Started PythonGatewayServer on port $boundPort") } - // Communicate the bound port back to the caller via the caller-specified callback port - val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") - val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt - logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") - val callbackSocket = new Socket(callbackHost, callbackPort) - val dos = new DataOutputStream(callbackSocket.getOutputStream) + // Communicate the connection information back to the python process by writing the + // information in the requested file. This needs to match the read side in java_gateway.py. 
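[Editor's aside] For reference, a sketch of the read side of this small file format, written in Scala rather than the actual Python in java_gateway.py, assuming the layout produced below: a 4-byte port, a 4-byte secret length, then the secret bytes:

{{{
import java.io.{DataInputStream, FileInputStream}
import java.nio.charset.StandardCharsets.UTF_8

def readConnectionInfo(path: String): (Int, String) = {
  val in = new DataInputStream(new FileInputStream(path))
  try {
    val port = in.readInt()         // bound Py4J port
    val secretLen = in.readInt()    // length of the auth secret in bytes
    val secretBytes = new Array[Byte](secretLen)
    in.readFully(secretBytes)
    (port, new String(secretBytes, UTF_8))
  } finally {
    in.close()
  }
}
}}}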
+ val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH")) + val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(), + "connection", ".info").toFile() + + val dos = new DataOutputStream(new FileOutputStream(tmpPath)) dos.writeInt(boundPort) + + val secretBytes = secret.getBytes(UTF_8) + dos.writeInt(secretBytes.length) + dos.write(secretBytes, 0, secretBytes.length) dos.close() - callbackSocket.close() + + if (!tmpPath.renameTo(connectionInfoPath)) { + logError(s"Unable to write connection information to $connectionInfoPath.") + System.exit(1) + } // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: while (System.in.read() != -1) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f6293c0dc5091..e639a842754bd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -38,18 +38,17 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ private[spark] class PythonRDD( parent: RDD[_], func: PythonFunction, - preservePartitoning: Boolean) + preservePartitoning: Boolean, + isFromBarrier: Boolean = false) extends RDD[Array[Byte]](parent) { - val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions: Array[Partition] = firstParent.partitions override val partitioner: Option[Partitioner] = { @@ -59,9 +58,12 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonRunner(func, bufferSize, reuseWorker) + val runner = PythonRunner(func) runner.compute(firstParent.iterator(split, context), split.index, context) } + + @transient protected lazy override val isBarrier_ : Boolean = + isFromBarrier || dependencies.exists(_.rdd.isBarrier()) } /** @@ -107,6 +109,12 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + // Authentication helper used when serving iterator data. + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new SocketAuthHelper(conf) + } + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) @@ -129,12 +137,13 @@ private[spark] object PythonRDD extends Logging { * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. 
*/ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int]): Int = { + partitions: JArrayList[Int]): Array[Any] = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = @@ -147,13 +156,14 @@ private[spark] object PythonRDD extends Logging { /** * A helper function to collect an RDD as an iterator, then serve it via socket. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def collectAndServe[T](rdd: RDD[T]): Int = { + def collectAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } - def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = { + def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") } @@ -384,8 +394,31 @@ private[spark] object PythonRDD extends Logging { * and send them into this connection. * * The thread will terminate after all the data are sent or any exceptions happen. + * + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def serveIterator[T](items: Iterator[T], threadName: String): Int = { + def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { + serveToStream(threadName) { out => + writeIteratorToStream(items, new DataOutputStream(out)) + } + } + + /** + * Create a socket server and background thread to execute the writeFunc + * with the given OutputStream. + * + * The socket server can only accept one connection, or close if no connection + * in 15 seconds. + * + * Once a connection comes in, it will execute the block of code and pass in + * the socket output stream. + * + * The thread will terminate after the block of code is executed or any + * exceptions happen. 
+ */ + private[spark] def serveToStream( + threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -395,11 +428,14 @@ private[spark] object PythonRDD extends Logging { override def run() { try { val sock = serverSocket.accept() - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + authHelper.authClient(sock) + + val out = new BufferedOutputStream(sock.getOutputStream) Utils.tryWithSafeFinally { - writeIteratorToStream(items, out) + writeFunc(out) } { out.close() + sock.close() } } catch { case NonFatal(e) => @@ -410,7 +446,7 @@ private[spark] object PythonRDD extends Logging { } }.start() - serverSocket.getLocalPort + Array(serverSocket.getLocalPort, authHelper.secret) } private def getMergedConf(confAsMap: java.util.HashMap[String, String], @@ -571,8 +607,9 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By */ private[spark] class PythonAccumulatorV2( @transient private val serverHost: String, - private val serverPort: Int) - extends CollectionAccumulator[Array[Byte]] { + private val serverPort: Int, + private val secretToken: String) + extends CollectionAccumulator[Array[Byte]] with Logging{ Utils.checkHost(serverHost) @@ -587,17 +624,22 @@ private[spark] class PythonAccumulatorV2( private def openSocket(): Socket = synchronized { if (socket == null || socket.isClosed) { socket = new Socket(serverHost, serverPort) + logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort") + // send the secret just for the initial authentication when opening a new connection + socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8)) } socket } // Need to override so the types match with PythonFunction - override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort) + override def copyAndReset(): PythonAccumulatorV2 = { + new PythonAccumulatorV2(serverHost, serverPort, secretToken) + } override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized { val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2] // This conditional isn't strictly speaking needed - merging only currently happens on the - // driver program - but that isn't gauranteed so incase this changes. + // driver program - but that isn't guaranteed so incase this changes. 
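To make the serveToStream contract described above concrete, here is a self-contained sketch of the same shape: bind an ephemeral localhost port, accept at most one connection within 15 seconds, run the supplied write function on a daemon thread, and hand the (port, secret) pair back to the caller. SocketAuthHelper is Spark-internal, so the client authentication step is only marked by a comment here; the real method returns the pair as an Array[Any], presumably so it crosses Py4J without Scala tuple support.

    import java.io.{BufferedOutputStream, OutputStream}
    import java.net.{InetAddress, ServerSocket}

    object ServeToStreamSketch {
      // Returns (port, secret); the caller passes both to the client out of band.
      def serve(threadName: String, secret: String)(writeFunc: OutputStream => Unit): (Int, String) = {
        val server = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
        server.setSoTimeout(15000) // give up if no client connects within 15 seconds

        new Thread(threadName) {
          setDaemon(true)
          override def run(): Unit = {
            try {
              val sock = server.accept() // at most one connection is ever served
              // A real implementation would verify the client's secret here before writing.
              val out = new BufferedOutputStream(sock.getOutputStream)
              try writeFunc(out) finally {
                out.close()
                sock.close()
              }
            } finally {
              server.close()
            }
          }
        }.start()

        (server.getLocalPort, secret)
      }
    }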
if (serverHost == null) { // We are on the worker super.merge(otherPythonAccumulator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f075a7e0eb0b4..6c7e8630789bd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -20,12 +20,15 @@ package org.apache.spark.api.python import java.io._ import java.net._ import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -40,6 +43,7 @@ private[spark] object PythonEvalType { val SQL_SCALAR_PANDAS_UDF = 200 val SQL_GROUPED_MAP_PANDAS_UDF = 201 val SQL_GROUPED_AGG_PANDAS_UDF = 202 + val SQL_WINDOW_AGG_PANDAS_UDF = 203 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -47,6 +51,7 @@ private[spark] object PythonEvalType { case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" + case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" } } @@ -58,14 +63,20 @@ private[spark] object PythonEvalType { */ private[spark] abstract class BasePythonRunner[IN, OUT]( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) extends Logging { require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + private val conf = SparkEnv.get.conf + private val bufferSize = conf.getInt("spark.buffer.size", 65536) + private val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + // each python worker gets an equal part of the allocation. the worker pool will grow to the + // number of concurrent tasks, which is determined by the number of cores in this executor. + private val memoryMb = conf.get(PYSPARK_EXECUTOR_MEMORY) + .map(_ / conf.getInt("spark.executor.cores", 1)) + // All the Python functions should have the same exec, version and envvars. protected val envVars = funcs.head.funcs.head.envVars protected val pythonExec = funcs.head.funcs.head.pythonExec @@ -74,6 +85,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // TODO: support accumulator in multiple UDF protected val accumulator = funcs.head.funcs.head.accumulator + // Expose a ServerSocket to support method calls via socket from Python side. + private[spark] var serverSocket: Option[ServerSocket] = None + + // Authentication helper used when serving method calls via socket from Python side. 
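The new memory setting above is split evenly across the Python worker pool, whose size tracks the executor's core count. A tiny sketch of that arithmetic and of the environment variable handed to each worker; the option and core values are illustrative, only the PYSPARK_EXECUTOR_MEMORY_MB name comes from the diff.

    object PythonWorkerMemorySketch {
      // e.g. 4096 MB of Python memory on a 4-core executor => 1024 MB per worker.
      def perWorkerMemoryMb(pysparkMemoryMb: Option[Long], executorCores: Int): Option[Long] =
        pysparkMemoryMb.map(_ / executorCores)

      def workerEnv(pysparkMemoryMb: Option[Long], executorCores: Int): Map[String, String] =
        perWorkerMemoryMb(pysparkMemoryMb, executorCores)
          .map(mb => Map("PYSPARK_EXECUTOR_MEMORY_MB" -> mb.toString))
          .getOrElse(Map.empty)
    }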
+ private lazy val authHelper = new SocketAuthHelper(conf) + def compute( inputIterator: Iterator[IN], partitionIndex: Int, @@ -85,6 +102,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( if (reuseWorker) { envVars.put("SPARK_REUSE_WORKER", "1") } + if (memoryMb.isDefined) { + envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", memoryMb.get.toString) + } val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool val released = new AtomicBoolean(false) @@ -92,7 +112,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Start a thread to feed the process input from our parent's iterator val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => writerThread.shutdownOnTaskCompletion() if (!reuseWorker || !released.get) { try { @@ -178,11 +198,85 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) + // Init a ServerSocket to accept method calls from Python side. + val isBarrier = context.isInstanceOf[BarrierTaskContext] + if (isBarrier) { + serverSocket = Some(new ServerSocket(/* port */ 0, + /* backlog */ 1, + InetAddress.getByName("localhost"))) + // A call to accept() for ServerSocket shall block infinitely. + serverSocket.map(_.setSoTimeout(0)) + new Thread("accept-connections") { + setDaemon(true) + + override def run(): Unit = { + while (!serverSocket.get.isClosed()) { + var sock: Socket = null + try { + sock = serverSocket.get.accept() + // Wait for function call from python side. + sock.setSoTimeout(10000) + authHelper.authClient(sock) + val input = new DataInputStream(sock.getInputStream()) + input.readInt() match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + // The barrier() function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + barrierAndServe(sock) + + case _ => + val out = new DataOutputStream(new BufferedOutputStream( + sock.getOutputStream)) + writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) + } + } catch { + case e: SocketException if e.getMessage.contains("Socket closed") => + // It is possible that the ServerSocket is not closed, but the native socket + // has already been closed, we shall catch and silently ignore this case. + } finally { + if (sock != null) { + sock.close() + } + } + } + } + }.start() + } + val secret = if (isBarrier) { + authHelper.secret + } else { + "" + } + // Close ServerSocket on task completion. + serverSocket.foreach { server => + context.addTaskCompletionListener[Unit](_ => server.close()) + } + val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) + if (boundPort == -1) { + val message = "ServerSocket failed to bind to Java side." 
+ logError(message) + throw new SparkException(message) + } else if (isBarrier) { + logDebug(s"Started ServerSocket on port $boundPort.") + } // Write out the TaskContextInfo + dataOut.writeBoolean(isBarrier) + dataOut.writeInt(boundPort) + val secretBytes = secret.getBytes(UTF_8) + dataOut.writeInt(secretBytes.length) + dataOut.write(secretBytes, 0, secretBytes.length) dataOut.writeInt(context.stageId()) dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) dataOut.writeLong(context.taskAttemptId()) + val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala + dataOut.writeInt(localProps.size) + localProps.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) @@ -234,6 +328,30 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } } } + + /** + * Gateway to call BarrierTaskContext.barrier(). + */ + def barrierAndServe(sock: Socket): Unit = { + require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") + + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + try { + context.asInstanceOf[BarrierTaskContext].barrier() + writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) + } catch { + case e: SparkException => + writeUTF(e.getMessage, out) + } finally { + out.close() + } + } + + def writeUTF(str: String, dataOut: DataOutputStream) { + val bytes = str.getBytes(UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } } abstract class ReaderIterator( @@ -376,20 +494,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = { - new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + def apply(func: PythonFunction): PythonRunner = { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func)))) } } /** * A helper class to run Python mapPartition in Spark. */ -private[spark] class PythonRunner( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean) +private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions]) extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + funcs, PythonEvalType.NON_UDF, Array(Array(0))) { protected override def newWriterThread( env: SparkEnv, @@ -456,3 +571,9 @@ private[spark] object SpecialLengths { val NULL = -5 val START_ARROW_STREAM = -6 } + +private[spark] object BarrierTaskContextMessageProtocol { + val BARRIER_FUNCTION = 1 + val BARRIER_RESULT_SUCCESS = "success" + val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." 
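The barrier protocol above frames every string as a 4-byte big-endian length followed by UTF-8 bytes (writeUTF), and identifies function calls by a single int code. A matching writer/reader pair using only java.io, with the code value copied from the object above:

    import java.io.{DataInputStream, DataOutputStream}
    import java.nio.charset.StandardCharsets.UTF_8

    object BarrierFramingSketch {
      val BARRIER_FUNCTION = 1 // value taken from BarrierTaskContextMessageProtocol above

      def writeUTF(str: String, out: DataOutputStream): Unit = {
        val bytes = str.getBytes(UTF_8)
        out.writeInt(bytes.length)
        out.write(bytes)
      }

      def readUTF(in: DataInputStream): String = {
        val bytes = new Array[Byte](in.readInt())
        in.readFully(bytes)
        new String(bytes, UTF_8)
      }
    }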
+} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 92e228a9dd10c..27a5e19f96a14 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 2340580b54f67..6afa37aa36fd3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -27,6 +27,7 @@ import scala.collection.mutable import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) @@ -67,6 +68,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String value }.getOrElse("pyspark.worker") + private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -108,6 +111,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } + + authHelper.authToServer(socket) daemonWorkers.put(socket, pid) socket } @@ -145,25 +150,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) val worker = pb.start() // Redirect worker stdout and stderr redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) - // Tell the worker our port - val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8) - out.write(serverSocket.getLocalPort + "\n") - out.flush() - - // Wait for it to connect to our socket + // Wait for it to connect to our socket, and validate the auth secret. 
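The worker launch above stops writing the port to the child's stdin and instead publishes it, together with the secret, through environment variables. A sketch of the connect-back from the worker's point of view; the variable names come from the diff, but the exact auth exchange (length-prefixed UTF-8 secret, "ok" reply) is an assumption made for illustration.

    import java.io.{DataInputStream, DataOutputStream}
    import java.net.{InetAddress, Socket}
    import java.nio.charset.StandardCharsets.UTF_8

    object WorkerConnectBackSketch {
      def main(args: Array[String]): Unit = {
        val port = sys.env("PYTHON_WORKER_FACTORY_PORT").toInt
        val secret = sys.env("PYTHON_WORKER_FACTORY_SECRET")

        val sock = new Socket(InetAddress.getLoopbackAddress, port)
        val out = new DataOutputStream(sock.getOutputStream)
        val bytes = secret.getBytes(UTF_8)
        out.writeInt(bytes.length) // assumed framing: length prefix + UTF-8 bytes
        out.write(bytes)
        out.flush()

        // Wait for the factory's verdict before doing any real work (assumed "ok" reply).
        val in = new DataInputStream(sock.getInputStream)
        val reply = new Array[Byte](in.readInt())
        in.readFully(reply)
        require(new String(reply, UTF_8) == "ok", "factory rejected the secret")
      }
    }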
serverSocket.setSoTimeout(10000) + try { val socket = serverSocket.accept() + authHelper.authClient(socket) simpleWorkers.put(socket, worker) return socket } catch { case e: Exception => - throw new SparkException("Python worker did not connect back in time", e) + throw new SparkException("Python worker failed to connect back.", e) } } finally { if (serverSocket != null) { @@ -187,6 +191,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() @@ -218,7 +223,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) - } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala new file mode 100644 index 0000000000000..ac6826a9ec774 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.SparkConf +import org.apache.spark.security.SocketAuthHelper + +private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) { + + override protected def readUtf8(s: Socket): String = { + SerDe.readString(new DataInputStream(s.getInputStream())) + } + + override protected def writeUtf8(str: String, s: Socket): Unit = { + val out = s.getOutputStream() + SerDe.writeString(new DataOutputStream(out), str) + out.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 2d1152a036449..7ce2581555014 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,8 +17,8 @@ package org.apache.spark.api.r -import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetAddress, InetSocketAddress, ServerSocket} +import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. @@ -45,7 +47,7 @@ private[spark] class RBackend { /** Tracks JVM objects returned to R for this RBackend instance. */ private[r] val jvmObjectTracker = new JVMObjectTracker - def init(): Int = { + def init(): (Int, RAuthHelper) = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) @@ -53,6 +55,7 @@ private[spark] class RBackend { conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) + val authHelper = new RAuthHelper(conf) bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) @@ -71,13 +74,16 @@ private[spark] class RBackend { new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) + .addLast(new RBackendAuthHandler(authHelper.secret)) .addLast("handler", handler) } }) channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() - channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + + val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + (port, authHelper) } def run(): Unit = { @@ -90,11 +96,11 @@ private[spark] class RBackend { channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) channelFuture = null } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully() } if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() + bootstrap.config().childGroup().shutdownGracefully() } bootstrap = null jvmObjectTracker.clear() @@ 
-116,7 +122,7 @@ private[spark] object RBackend extends Logging { val sparkRBackend = new RBackend() try { // bind to random port - val boundPort = sparkRBackend.init() + val (boundPort, authHelper) = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // Connection timeout is set by socket client. To make it configurable we will pass the @@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.writeInt(backendConnectionTimeout) + SerDe.writeString(dos, authHelper.secret) dos.close() f.renameTo(new File(path)) @@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging { val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) + + // Wait for the R process to connect back, ignoring any failed auth attempts. Allow + // a max number of connection attempts to avoid looping forever. try { - val inSocket = serverSocket.accept() + var remainingAttempts = 10 + var inSocket: Socket = null + while (inSocket == null) { + inSocket = serverSocket.accept() + try { + authHelper.authClient(inSocket) + } catch { + case e: Exception => + remainingAttempts -= 1 + if (remainingAttempts == 0) { + val msg = "Too many failed authentication attempts." + logError(msg) + throw new IllegalStateException(msg) + } + logInfo("Client connection failed authentication.") + inSocket = null + } + } + serverSocket.close() + // wait for the end of socket, closed if R process die inSocket.getInputStream().read(buf) } finally { + serverSocket.close() sparkRBackend.close() System.exit(0) } @@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging { } System.exit(0) } + } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala new file mode 100644 index 0000000000000..4162e4a6c7476 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.charset.StandardCharsets.UTF_8 + +import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Authentication handler for connections from the R process. 
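The handshake above keeps accepting connections until a client authenticates, giving up after ten failed attempts. A stripped-down version of that accept-and-retry loop, with the real SerDe-based secret exchange replaced by a single line read purely for illustration:

    import java.io.{BufferedReader, InputStreamReader}
    import java.net.{ServerSocket, Socket}
    import java.nio.charset.StandardCharsets.UTF_8

    object BoundedAcceptSketch {
      // Accepts until a client presents the right secret, or maxAttempts clients fail.
      def acceptAuthenticated(server: ServerSocket, secret: String, maxAttempts: Int = 10): Socket = {
        var remaining = maxAttempts
        var accepted: Socket = null
        while (accepted == null) {
          val sock = server.accept()
          val line = new BufferedReader(new InputStreamReader(sock.getInputStream, UTF_8)).readLine()
          if (line == secret) {
            accepted = sock
          } else {
            sock.close()
            remaining -= 1
            if (remaining == 0) {
              throw new IllegalStateException("Too many failed authentication attempts.")
            }
          }
        }
        accepted
      }
    }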
+ */ +private class RBackendAuthHandler(secret: String) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + // The R code adds a null terminator to serialized strings, so ignore it here. + val clientSecret = new String(msg, 0, msg.length - 1, UTF_8) + try { + require(secret == clientSecret, "Auth secret mismatch.") + ctx.pipeline().remove(this) + writeReply("ok", ctx.channel()) + } catch { + case e: Exception => + logInfo("Authentication failure.", e) + writeReply("err", ctx.channel()) + ctx.close() + } + } + + private def writeReply(reply: String, chan: Channel): Unit = { + val out = new ByteArrayOutputStream() + SerDe.writeString(new DataOutputStream(out), reply) + chan.writeAndFlush(out.toByteArray()) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 88118392003e8..e7fdc3963945a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -74,14 +74,19 @@ private[spark] class RRunner[U]( // the socket used to send out the input of task serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() + dataStream = try { + val inSocket = serverSocket.accept() + RRunner.authHelper.authClient(inSocket) + startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + RRunner.authHelper.authClient(outSocket) + val inputStream = new BufferedInputStream(outSocket.getInputStream) + new DataInputStream(inputStream) + } finally { + serverSocket.close() + } try { return new Iterator[U] { @@ -315,6 +320,11 @@ private[r] object RRunner { private[this] var errThread: BufferedStreamThread = _ private[this] var daemonChannel: DataOutputStream = _ + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new RAuthHelper(conf) + } + /** * Start a thread to print the process's stderr to ours */ @@ -349,6 +359,7 @@ private[r] object RRunner { pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") + pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) @@ -370,8 +381,12 @@ private[r] object RRunner { // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() + try { + authHelper.authClient(sock) + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + } finally { + serverSocket.close() + } } try { daemonChannel.writeInt(port) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala 
index e125095cf4777..cbd49e070f2eb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -262,7 +262,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) val blockManager = SparkEnv.get.blockManager Option(TaskContext.get()) match { case Some(taskContext) => - taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId)) + taskContext.addTaskCompletionListener[Unit](_ => blockManager.releaseLock(blockId)) case None => // This should only happen on the driver, where broadcast variables may be accessed // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index f975fa5cb4e23..b59a4fe66587c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -94,6 +94,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana blockHandler.applicationRemoved(appId, true /* cleanupLocalDirs */) } + /** Clean up all the non-shuffle files associated with an executor that has exited. */ + def executorRemoved(executorId: String, appId: String): Unit = { + blockHandler.executorRemoved(executorId, appId) + } + def stop() { if (server != null) { server.close() diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 7aca305783a7f..ccb30e205ca40 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,7 +18,8 @@ package org.apache.spark.deploy import java.io.File -import java.net.URI +import java.net.{InetAddress, URI} +import java.nio.file.Files import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -39,6 +40,7 @@ object PythonRunner { val pyFiles = args(1) val otherArgs = args.slice(2, args.length) val sparkConf = new SparkConf() + val secret = Utils.createSecret(sparkConf) val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON) .orElse(sparkConf.get(PYSPARK_PYTHON)) .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON")) @@ -47,11 +49,17 @@ object PythonRunner { // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) - val formattedPyFiles = formatPaths(pyFiles) + val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles)) // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such - val gatewayServer = new py4j.GatewayServer(null, 0) + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() val thread = new Thread(new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions { gatewayServer.start() @@ -82,6 +90,7 @@ object PythonRunner { // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) + 
env.put("PYSPARK_GATEWAY_SECRET", secret) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) @@ -145,4 +154,30 @@ object PythonRunner { .map { p => formatPath(p, testWindows) } } + /** + * Resolves the ".py" files. ".py" file should not be added as is because PYTHONPATH does + * not expect a file. This method creates a temporary directory and puts the ".py" files + * if exist in the given paths. + */ + private def resolvePyFiles(pyFiles: Array[String]): Array[String] = { + lazy val dest = Utils.createTempDir(namePrefix = "localPyFiles") + pyFiles.flatMap { pyFile => + // In case of client with submit, the python paths should be set before context + // initialization because the context initialization can be done later. + // We will copy the local ".py" files because ".py" file shouldn't be added + // alone but its parent directory in PYTHONPATH. See SPARK-24384. + if (pyFile.endsWith(".py")) { + val source = new File(pyFile) + if (source.exists() && source.isFile && source.canRead) { + Files.copy(source.toPath, new File(dest, source.getName).toPath) + Some(dest.getAbsolutePath) + } else { + // Don't have to add it if it doesn't exist or isn't readable. + None + } + } else { + Some(pyFile) + } + }.distinct + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 6eb53a8252205..e86b362639e57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -68,10 +68,13 @@ object RRunner { // Java system properties etc. val sparkRBackend = new RBackend() @volatile var sparkRBackendPort = 0 + @volatile var sparkRBackendSecret: String = null val initialized = new Semaphore(0) val sparkRBackendThread = new Thread("SparkR backend") { override def run() { - sparkRBackendPort = sparkRBackend.init() + val (port, authHelper) = sparkRBackend.init() + sparkRBackendPort = port + sparkRBackendSecret = authHelper.secret initialized.release() sparkRBackend.run() } @@ -91,6 +94,7 @@ object RRunner { env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 8353e64a619cf..4cc0063d010ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -31,7 +31,6 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} -import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -108,7 +107,7 @@ class SparkHadoopUtil extends Logging { } /** - * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop + * Return an appropriate (subclass) of Configuration. 
Creating config can initialize some Hadoop * subsystems. */ def newConfiguration(conf: SparkConf): Configuration = { @@ -367,28 +366,6 @@ class SparkHadoopUtil extends Logging { buffer.toString } - private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { - val perm = status.getPermission - val ugi = UserGroupInformation.getCurrentUser - - if (ugi.getShortUserName == status.getOwner) { - if (perm.getUserAction.implies(mode)) { - return true - } - } else if (ugi.getGroupNames.contains(status.getGroup)) { - if (perm.getGroupAction.implies(mode)) { - return true - } - } else if (perm.getOtherAction.implies(mode)) { - return true - } - - logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + - s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + - s"${if (status.isDirectory) "d" else "-"}$perm") - false - } - def serialize(creds: Credentials): Array[Byte] = { val byteStream = new ByteArrayOutputStream val dataStream = new DataOutputStream(byteStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 427c797755b84..cf902db8709e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -22,6 +22,7 @@ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowab import java.net.URL import java.security.PrivilegedExceptionAction import java.text.ParseException +import java.util.UUID import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -180,6 +181,7 @@ private[spark] class SparkSubmit extends Logging { if (args.isStandaloneCluster && args.useRest) { try { logInfo("Running Spark using the REST application submission protocol.") + doRunMain() } catch { // Fail over to use the legacy submission gateway case e: SubmitRestConnectionException => @@ -284,10 +286,6 @@ private[spark] class SparkSubmit extends Logging { case (STANDALONE, CLUSTER) if args.isR => error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") - case (KUBERNETES, _) if args.isPython => - error("Python applications are currently not supported for Kubernetes.") - case (KUBERNETES, _) if args.isR => - error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => error("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => @@ -309,6 +307,7 @@ private[spark] class SparkSubmit extends Logging { val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER + val isMesosClient = clusterManager == MESOS && deployMode == CLIENT if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. 
Add them to py-files @@ -336,7 +335,7 @@ private[spark] class SparkSubmit extends Logging { val targetDir = Utils.createTempDir() // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient) { if (args.principal != null) { if (args.keytab != null) { require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") @@ -385,7 +384,7 @@ private[spark] class SparkSubmit extends Logging { val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) def shouldDownload(scheme: String): Boolean = { - forceDownloadSchemes.contains(scheme) || + forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) || Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure } @@ -428,18 +427,15 @@ private[spark] class SparkSubmit extends Logging { // Usage: PythonAppRunner
<main python file> <extra python files>
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs - if (clusterManager != YARN) { - // The YARN backend distributes the primary file differently, so don't merge it. - args.files = mergeFileLists(args.files, args.primaryResource) - } } if (clusterManager != YARN) { // The YARN backend handles python files differently, so don't merge the lists. args.files = mergeFileLists(args.files, args.pyFiles) } - if (localPyFiles != null) { - sparkConf.set("spark.submit.pyFiles", localPyFiles) - } + } + + if (localPyFiles != null) { + sparkConf.set("spark.submit.pyFiles", localPyFiles) } // In YARN mode for an R app, add the SparkR package archive and the R package @@ -581,7 +577,8 @@ private[spark] class SparkSubmit extends Logging { } // Add the main application jar and any added jars to classpath in case YARN client // requires these jars. - // This assumes both primaryResource and user jars are local jars, otherwise it will not be + // This assumes both primaryResource and user jars are local jars, or already downloaded + // to local by configuring "spark.yarn.dist.forceDownloadSchemes", otherwise it will not be // added to the classpath of YARN client. if (isYarnCluster) { if (isUserJar(args.primaryResource)) { @@ -695,9 +692,23 @@ private[spark] class SparkSubmit extends Logging { if (isKubernetesCluster) { childMainClass = KUBERNETES_CLUSTER_SUBMIT_CLASS if (args.primaryResource != SparkLauncher.NO_RESOURCE) { - childArgs ++= Array("--primary-java-resource", args.primaryResource) + if (args.isPython) { + childArgs ++= Array("--primary-py-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.PythonRunner") + if (args.pyFiles != null) { + childArgs ++= Array("--other-py-files", args.pyFiles) + } + } else if (args.isR) { + childArgs ++= Array("--primary-r-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.RRunner") + } + else { + childArgs ++= Array("--primary-java-resource", args.primaryResource) + childArgs ++= Array("--main-class", args.mainClass) + } + } else { + childArgs ++= Array("--main-class", args.mainClass) } - childArgs ++= Array("--main-class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) @@ -1204,7 +1215,33 @@ private[spark] object SparkSubmitUtils { /** A nice function to use in tests as well. Values are dummy strings. */ def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( - ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + // Include UUID in module name, so multiple clients resolving maven coordinate at the same time + // do not modify the same resolution file concurrently. + ModuleRevisionId.newInstance("org.apache.spark", + s"spark-submit-parent-${UUID.randomUUID.toString}", + "1.0")) + + /** + * Clear ivy resolution from current launch. The resolution file is usually at + * ~/.ivy2/org.apache.spark-spark-submit-parent-$UUID-default.xml, + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.xml, and + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.properties. + * Since each launch will have its own resolution files created, delete them after + * each resolution to prevent accumulation of these files in the ivy cache dir. 
+ */ + private def clearIvyResolutionFiles( + mdId: ModuleRevisionId, + ivySettings: IvySettings, + ivyConfName: String): Unit = { + val currentResolutionFiles = Seq( + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.properties" + ) + currentResolutionFiles.foreach { filename => + new File(ivySettings.getDefaultCache, filename).delete() + } + } /** * Resolves any dependencies that were supplied through maven coordinates @@ -1255,14 +1292,6 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor - // clear ivy resolution from previous launches. The resolution file is usually at - // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file - // leads to confusion with Ivy when the files can no longer be found at the repository - // declared in that file/ - val mdId = md.getModuleRevisionId - val previousResolution = new File(ivySettings.getDefaultCache, - s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") - if (previousResolution.exists) previousResolution.delete md.setDefaultConf(ivyConfName) @@ -1283,7 +1312,10 @@ private[spark] object SparkSubmitUtils { packagesDirectory.getAbsolutePath + File.separator + "[organization]_[artifact]-[revision](-[classifier]).[ext]", retrieveOptions.setConfs(Array(ivyConfName))) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val paths = resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val mdId = md.getModuleRevisionId + clearIvyResolutionFiles(mdId, ivySettings, ivyConfName) + paths } finally { System.setOut(sysOut) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0733fdb72cafb..0998757715457 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -36,7 +36,6 @@ import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils - /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. @@ -76,13 +75,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var proxyUser: String = null var principal: String = null var keytab: String = null + private var dynamicAllocationEnabled: Boolean = false // Standalone cluster mode only var supervise: Boolean = false var driverCores: String = null var submissionToKill: String = null var submissionToRequestStatusFor: String = null - var useRest: Boolean = true // used internally + var useRest: Boolean = false // used internally /** Default properties present in the currently defined defaults file. 
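Because the module name now carries a UUID, every launch produces its own trio of resolution files, which can then be deleted without racing concurrent clients. A sketch of the naming and cleanup described in the comment above; the cache directory, organisation and configuration name are placeholders for what Ivy provides at runtime.

    import java.io.File
    import java.util.UUID

    object IvyCleanupSketch {
      // The UUID is baked into the module name once, when the dummy descriptor is created.
      def freshModuleName(): String = s"spark-submit-parent-${UUID.randomUUID()}"

      // File name patterns follow the comment above.
      def resolutionFiles(cacheDir: File, org: String, name: String, rev: String,
          ivyConf: String): Seq[File] = Seq(
        s"$org-$name-$ivyConf.xml",
        s"resolved-$org-$name-$rev.xml",
        s"resolved-$org-$name-$rev.properties").map(new File(cacheDir, _))

      def clear(cacheDir: File, org: String, name: String, rev: String, ivyConf: String): Unit =
        resolutionFiles(cacheDir, org, name, rev, ivyConf).foreach(_.delete())
    }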
*/ lazy val defaultSparkProperties: HashMap[String, String] = { @@ -115,6 +115,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() + useRest = sparkProperties.getOrElse("spark.master.rest.enabled", "false").toBoolean + validateArguments() /** @@ -182,6 +184,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull + pyFiles = Option(pyFiles).orElse(sparkProperties.get("spark.submit.pyFiles")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull @@ -198,6 +201,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + dynamicAllocationEnabled = + sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase) // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -274,12 +279,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { error("Total executor cores must be a positive number") } - if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { + if (!dynamicAllocationEnabled && + numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } - if (pyFiles != null && !isPython) { - error("--py-files given but primary resource is not a Python script") - } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 56db9359e033f..44d23908146c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,12 +18,15 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} +import java.nio.file.Files +import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} -import java.util.concurrent.{ExecutorService, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.ExecutionException import scala.io.Source import scala.util.Try import scala.xml.Node @@ -31,8 +34,7 @@ import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams import com.google.common.util.concurrent.MoreExecutors -import org.apache.hadoop.fs.{FileStatus, Path} -import 
org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -112,7 +114,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) "; groups with admin permissions" + HISTORY_UI_ADMIN_ACLS_GROUPS.toString) private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private val fs = new Path(logDir).getFileSystem(hadoopConf) + // Visible for testing + private[history] val fs: FileSystem = new Path(logDir).getFileSystem(hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -130,8 +133,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - require(path.isDirectory(), s"Configured store directory ($path) does not exist.") - val dbPath = new File(path, "listing.ldb") + val perms = PosixFilePermissions.fromString("rwx------") + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), + PosixFilePermissions.asFileAttribute(perms)).toFile() + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) @@ -157,6 +162,25 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) new HistoryServerDiskManager(conf, path, listing, clock) } + private val blacklist = new ConcurrentHashMap[String, Long] + + // Visible for testing + private[history] def isBlacklisted(path: Path): Boolean = { + blacklist.containsKey(path.getName) + } + + private def blacklist(path: Path): Unit = { + blacklist.put(path.getName, clock.getTimeMillis()) + } + + /** + * Removes expired entries in the blacklist, according to the provided `expireTimeInSeconds`. + */ + private def clearBlacklist(expireTimeInSeconds: Long): Unit = { + val expiredThreshold = clock.getTimeMillis() - expireTimeInSeconds * 1000 + blacklist.asScala.retain((_, creationTime) => creationTime >= expiredThreshold) + } + private val activeUIs = new mutable.HashMap[(String, Option[String]), LoadedAppUI]() /** @@ -414,7 +438,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + !isBlacklisted(entry.getPath) } .filter { entry => try { @@ -457,32 +481,37 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - val tasks = updated.map { entry => + val tasks = updated.flatMap { entry => try { - replayExecutor.submit(new Runnable { + val task: Future[Unit] = replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true) - }) + }, Unit) + Some(task -> entry.getPath) } catch { // let the iteration over the updated entries break, since an exception on // replayExecutor.submit (..) 
indicates the ExecutorService is unable // to take any more submissions at this time case e: Exception => logError(s"Exception while submitting event log for replay", e) - null + None } - }.filter(_ != null) + } pendingReplayTasksCount.addAndGet(tasks.size) // Wait for all tasks to finish. This makes sure that checkForLogs // is not scheduled again while some tasks are already running in // the replayExecutor. - tasks.foreach { task => + tasks.foreach { case (task, path) => try { task.get() } catch { case e: InterruptedException => throw e + case e: ExecutionException if e.getCause.isInstanceOf[AccessControlException] => + // We don't have read permissions on the log file + logWarning(s"Unable to read log $path", e.getCause) + blacklist(path) case e: Exception => logError("Exception while merging application listings", e) } finally { @@ -775,6 +804,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) listing.delete(classOf[LogInfo], log.logPath) } } + // Clean the blacklist from the expired entries. + clearBlacklist(CLEAN_INTERVAL_S) } /** @@ -934,13 +965,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } private def deleteLog(log: Path): Unit = { - try { - fs.delete(log, true) - } catch { - case _: AccessControlException => - logInfo(s"No permission to delete $log, ignoring.") - case ioe: IOException => - logError(s"IOException in cleaning $log", ioe) + if (isBlacklisted(log)) { + logDebug(s"Skipping deleting $log as we don't have permissions on it.") + } else { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6fc12d721e6f1..32667ddf5c7ea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,8 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - ++ - + ++ +
- UIUtils.basicSparkPage(content, "History Server", true) + UIUtils.basicSparkPage(request, content, "History Server", true) } - private def makePageLink(showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete) + private def makePageLink(request: HttpServletRequest, showIncomplete: Boolean): String = { + UIUtils.prependBaseUri(request, "/?" + "showIncomplete=" + showIncomplete) } private def isApplicationCompleted(appInfo: ApplicationInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 611fa563a7cd9..56f3f59504a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -87,7 +87,7 @@ class HistoryServer( if (!loadAppUi(appId, None) && (!attemptId.isDefined || !loadAppUi(appId, attemptId))) { val msg =
<div class="row-fluid">Application {appId} not found.</div>
res.setStatus(HttpServletResponse.SC_NOT_FOUND) - UIUtils.basicSparkPage(msg, "Not Found").foreach { n => + UIUtils.basicSparkPage(req, msg, "Not Found").foreach { n => res.getWriter().write(n.toString) } return @@ -124,7 +124,7 @@ class HistoryServer( attachHandler(ApiRootResource.getServletHandler(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) val contextHandler = new ServletContextHandler contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) @@ -152,7 +152,6 @@ class HistoryServer( assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") handlers.synchronized { ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 2c78c15773af2..e1184248af460 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -121,10 +121,18 @@ private[deploy] class Master( } // Alternative application submission gateway that is stable across Spark versions - private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) + private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", false) private var restServer: Option[StandaloneRestServer] = None private var restServerBoundPort: Option[Int] = None + { + val authKey = SecurityManager.SPARK_AUTH_SECRET_CONF + require(conf.getOption(authKey).isEmpty || !restServerEnabled, + s"The RestSubmissionServer does not support authentication via ${authKey}. Either turn " + + "off the RestSubmissionServer with spark.master.rest.enabled=false, or do not use " + + "authentication.") + } + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f699c75085fe1..fad4e46dc035d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -40,7 +40,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { val msg =
No running application with ID {appId}
- return UIUtils.basicSparkPage(msg, "Not Found") + return UIUtils.basicSparkPage(request, msg, "Not Found") } val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") @@ -127,7 +127,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") } ; - UIUtils.basicSparkPage(content, "Application: " + app.desc.name) + UIUtils.basicSparkPage(request, content, "Application: " + app.desc.name) } private def executorRow(executor: ExecutorDesc): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index c629937606b51..b8afe203fbfa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -215,7 +215,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) + UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } private def workerRow(worker: WorkerInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 35b7ddd46e4db..e87b2240564bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,7 +43,7 @@ class MasterWebUI( val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(masterPage) - attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 742a95841a138..31a8e3e60c067 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -233,30 +233,44 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { import scala.concurrent.ExecutionContext.Implicits.global val responseFuture = Future { - val dataStream = - if (connection.getResponseCode == HttpServletResponse.SC_OK) { - connection.getInputStream - } else { - connection.getErrorStream + val responseCode = connection.getResponseCode + + if (responseCode != HttpServletResponse.SC_OK) { + val errString = Some(Source.fromInputStream(connection.getErrorStream()) + .getLines().mkString("\n")) + if (responseCode == HttpServletResponse.SC_INTERNAL_SERVER_ERROR && + !connection.getContentType().contains("application/json")) { + throw new SubmitRestProtocolException(s"Server responded with exception:\n${errString}") + } + logError(s"Server responded with error:\n${errString}") + val error = new ErrorResponse + if (responseCode == RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) { + error.highestProtocolVersion = RestSubmissionServer.PROTOCOL_VERSION + } + error.message = errString.get + error + } else { + val dataStream = connection.getInputStream + + // If the server threw an exception 
while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") } - // If the server threw an exception while writing a response, it will not have a body - if (dataStream == null) { - throw new SubmitRestProtocolException("Server returned empty body") - } - val responseJson = Source.fromInputStream(dataStream).mkString - logDebug(s"Response from the server:\n$responseJson") - val response = SubmitRestProtocolMessage.fromJson(responseJson) - response.validate() - response match { - // If the response is an error, log the message - case error: ErrorResponse => - logError(s"Server responded with error:\n${error.message}") - error - // Otherwise, simply return the response - case response: SubmitRestProtocolResponse => response - case unexpected => - throw new SubmitRestProtocolException( - s"Message received from server was not a response:\n${unexpected.toJson}") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index e88195d95f270..e59bf3f0eaf44 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -51,6 +51,7 @@ private[spark] abstract class RestSubmissionServer( val host: String, val requestedPort: Int, val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet protected val killRequestServlet: KillRequestServlet protected val statusRequestServlet: StatusRequestServlet @@ -94,6 +95,7 @@ private[spark] abstract class RestSubmissionServer( new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) + connector.setReuseAddress(!Utils.isWindows) server.addConnector(connector) val mainHandler = new ServletContextHandler diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 5151df00476f9..ab8d8d96a9b08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging * * Also, each HadoopDelegationTokenProvider is controlled by * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to - * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be * enabled/disabled by the configuration spark.security.credentials.hive.enabled. 
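A minimal illustration of the provider switch described in the scaladoc above; the config key comes from that text, the rest is an editor's sketch rather than part of the patch:

  // disable the Hive delegation token provider so the manager skips it
  import org.apache.spark.SparkConf
  val conf = new SparkConf()
    .set("spark.security.credentials.hive.enabled", "false")
  // with this set, HiveDelegationTokenProvider is left out of the providers that
  // HadoopDelegationTokenManager loads and logs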
* * @param sparkConf Spark configuration @@ -52,7 +52,7 @@ private[spark] class HadoopDelegationTokenManager( // Maintain all the registered delegation token providers private val delegationTokenProviders = getDelegationTokenProviders - logDebug(s"Using the following delegation token providers: " + + logDebug("Using the following builtin delegation token providers: " + s"${delegationTokenProviders.keys.mkString(", ")}.") /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 58a181128eb4d..a6d13d12fc28d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -225,7 +225,7 @@ private[deploy] class DriverRunner( // check if attempting another run keepTrying = supervise && exitCode != 0 && !killed if (keepTrying) { - if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) { waitSeconds = 1 } logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d4d8521cc8204..dc6a3076a5113 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.Files import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef @@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + val subsOpts = appDesc.command.javaOpts.map { + Utils.substituteAppNExecIds(_, appId, execId.toString) + } + val subsCommand = appDesc.command.copy(javaOpts = subsOpts) + val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 563b84934f264..cbd812a05a2c6 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -23,6 +23,7 @@ import java.text.SimpleDateFormat import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} +import java.util.function.Supplier import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext @@ -49,7 +50,8 @@ private[deploy] class Worker( endpointName: String, workDirPath: String = null, val conf: SparkConf, - val securityMgr: SecurityManager) + val 
securityMgr: SecurityManager, + externalShuffleServiceSupplier: Supplier[ExternalShuffleService] = null) extends ThreadSafeRpcEndpoint with Logging { private val host = rpcEnv.address.host @@ -97,6 +99,10 @@ private[deploy] class Worker( private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) + // Whether or not cleanup the non-shuffle files on executor exits. + private val CLEANUP_NON_SHUFFLE_FILES_ENABLED = + conf.getBoolean("spark.storage.cleanupFilesAfterExecutorExit", true) + private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None @@ -142,7 +148,11 @@ private[deploy] class Worker( WorkerWebUI.DEFAULT_RETAINED_DRIVERS) // The shuffle service is not actually started unless configured. - private val shuffleService = new ExternalShuffleService(conf, securityMgr) + private val shuffleService = if (externalShuffleServiceSupplier != null) { + externalShuffleServiceSupplier.get() + } else { + new ExternalShuffleService(conf, securityMgr) + } private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -732,6 +742,9 @@ private[deploy] class Worker( trimFinishedExecutorsIfNecessary() coresUsed -= executor.cores memoryUsed -= executor.memory + if (CLEANUP_NON_SHUFFLE_FILES_ENABLED) { + shuffleService.executorRemoved(executorStateChanged.execId.toString, appId) + } case None => logInfo("Unknown Executor " + fullId + " finished with state " + state + message.map(" message " + _).getOrElse("") + @@ -745,6 +758,7 @@ private[deploy] class Worker( private[deploy] object Worker extends Logging { val SYSTEM_NAME = "sparkWorker" val ENDPOINT_NAME = "Worker" + private val SSL_NODE_LOCAL_CONFIG_PATTERN = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r def main(argStrings: Array[String]) { Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler( @@ -790,9 +804,8 @@ private[deploy] object Worker extends Logging { } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { - val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r val result = cmd.javaOpts.collectFirst { - case pattern(_result) => _result.toBoolean + case SSL_NODE_LOCAL_CONFIG_PATTERN(_result) => _result.toBoolean } result.getOrElse(false) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 2f5a5642d3cab..4fca9342c0378 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -118,7 +118,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with - UIUtils.basicSparkPage(content, logType + " log page for " + pageName) + UIUtils.basicSparkPage(request, content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 8b98ae56fc108..aa4e28d213e2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -135,7 +135,7 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( + UIUtils.basicSparkPage(request, content, "Spark Worker at %s:%s".format( 
workerState.host, workerState.port)) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index db696b04384bd..ea67b7434a769 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -47,7 +47,7 @@ class WorkerWebUI( val logPage = new LogPage(this) attachPage(logPage) attachPage(new WorkerPage(this)) - attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) + addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr, diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c325222b764b8..86b19578037df 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -287,6 +287,28 @@ private[spark] class Executor( notifyAll() } + /** + * Utility function to: + * 1. Report executor runtime and JVM gc time if possible + * 2. Collect accumulator updates + * 3. Set the finished flag to true and clear current thread's interrupt status + */ + private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = { + // Report executor runtime and JVM gc time + Option(task).foreach(t => { + t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime) + t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + }) + + // Collect latest accumulator values to report back to the driver + val accums: Seq[AccumulatorV2[_, _]] = + Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + + setTaskFinishedAndClearInterruptStatus() + (accums, accUpdates) + } + override def run(): Unit = { threadId = Thread.currentThread.getId Thread.currentThread.setName(threadName) @@ -300,7 +322,7 @@ private[spark] class Executor( val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var taskStart: Long = 0 + var taskStartTime: Long = 0 var taskStartCpu: Long = 0 startGCTime = computeTotalGcTime() @@ -336,19 +358,19 @@ private[spark] class Executor( } // Run the actual task and measure its runtime. - taskStart = System.currentTimeMillis() + taskStartTime = System.currentTimeMillis() taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L var threwException = true - val value = try { + val value = Utils.tryWithSafeFinally { val res = task.run( taskAttemptId = taskId, attemptNumber = taskDescription.attemptNumber, metricsSystem = env.metricsSystem) threwException = false res - } finally { + } { val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() @@ -396,11 +418,11 @@ private[spark] class Executor( // Deserialization happens in two parts: first, we deserialize a Task object, which // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. 
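// Editor's illustration with made-up numbers (not part of the patch) of how the metric
// updates just below avoid double counting: if deserializeStartTime = 1000 ms,
// taskStartTime = 1040 ms, taskFinish = 1540 ms and Task.run() reported
// task.executorDeserializeTime = 15 ms, then
//   executorDeserializeTime = (1040 - 1000) + 15 = 55 ms
//   executorRunTime         = (1540 - 1040) - 15 = 485 ms
// so the 15 ms spent deserializing the RDD and closure inside Task.run() is counted as
// deserialization once and excluded from the run time.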
task.metrics.setExecutorDeserializeTime( - (taskStart - deserializeStartTime) + task.executorDeserializeTime) + (taskStartTime - deserializeStartTime) + task.executorDeserializeTime) task.metrics.setExecutorDeserializeCpuTime( (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) // We need to subtract Task.run()'s deserialization time to avoid double-counting - task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime) task.metrics.setExecutorCpuTime( (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) @@ -482,16 +504,19 @@ private[spark] class Executor( } catch { case t: TaskKilledException => logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case _: InterruptedException | NonFatal(_) if task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason @@ -524,17 +549,7 @@ private[spark] class Executor( // the task failure would not be ignored if the shutdown happened because of premption, // instead of an app issue). 
if (!ShutdownHookManager.inShutdown()) { - // Collect latest accumulator values to report back to the driver - val accums: Seq[AccumulatorV2[_, _]] = - if (task != null) { - task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) - task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) - task.collectAccumulatorUpdates(taskFailed = true) - } else { - Seq.empty - } - - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) val serializedTaskEndReason = { try { diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 17cdba4f1305b..ab020aaf6fa4f 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -47,7 +47,7 @@ private[spark] abstract class StreamFileInputFormat[T] def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) { val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) - val defaultParallelism = sc.defaultParallelism + val defaultParallelism = Math.max(sc.defaultParallelism, minPartitions) val files = listStatus(context).asScala val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum val bytesPerCore = totalBytes / defaultParallelism diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index f47cd38d712c3..04c5c4b90e8a1 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -53,6 +53,19 @@ private[spark] class WholeTextFileInputFormat val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong + + // For small files we need to ensure the min split size per node & rack <= maxSplitSize + val config = context.getConfiguration + val minSplitSizePerNode = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERNODE, 0L) + val minSplitSizePerRack = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERRACK, 0L) + + if (maxSplitSize < minSplitSizePerNode) { + super.setMinSplitSizeNode(maxSplitSize) + } + + if (maxSplitSize < minSplitSizePerRack) { + super.setMinSplitSizeRack(maxSplitSize) + } super.setMaxSplitSize(maxSplitSize) } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 99d779fb600e8..7c2f601c9986a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -72,6 +72,9 @@ package object config { private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EVENT_LOG_CALLSITE_FORM = + ConfigBuilder("spark.eventLog.callsite").stringConf.createWithDefault("short") + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional @@ -111,6 +114,10 @@ package object config { .checkValue(_ >= 0, "The off-heap memory size must not be negative") 
.createWithDefault(0) + private[spark] val PYSPARK_EXECUTOR_MEMORY = ConfigBuilder("spark.executor.pyspark.memory") + .bytesConf(ByteUnit.MiB) + .createOptional + private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal() .booleanConf.createWithDefault(false) @@ -126,6 +133,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO = + ConfigBuilder("spark.dynamicAllocation.executorAllocationRatio") + .doubleConf.createWithDefault(1.0) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") @@ -338,7 +349,7 @@ package object config { "a property key or value, the value is redacted from the environment UI and various logs " + "like YARN and event logs.") .regexConf - .createWithDefault("(?i)secret|password|url|user|username".r) + .createWithDefault("(?i)secret|password".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") @@ -348,6 +359,11 @@ package object config { .regexConf .createOptional + private[spark] val AUTH_SECRET_BIT_LENGTH = + ConfigBuilder("spark.authenticate.secretBitLength") + .intConf + .createWithDefault(256) + private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") .booleanConf @@ -420,7 +436,11 @@ package object config { "external shuffle service, this feature can only be worked when external shuffle" + "service is newer than Spark 2.2.") .bytesConf(ByteUnit.BYTE) - .createWithDefault(Long.MaxValue) + // fetch-to-mem is guaranteed to fail if the message is bigger than 2 GB, so we might + // as well use fetch-to-disk in that case. The message includes some metadata in addition + // to the block data itself (in particular UploadBlock has a lot of metadata), so we leave + // extra room. + .createWithDefault(Int.MaxValue - 512) private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") @@ -474,10 +494,11 @@ package object config { private[spark] val FORCE_DOWNLOAD_SCHEMES = ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") - .doc("Comma-separated list of schemes for which files will be downloaded to the " + + .doc("Comma-separated list of schemes for which resources will be downloaded to the " + "local disk prior to being added to YARN's distributed cache. For use in cases " + "where the YARN service does not support schemes that are supported by Spark, like http, " + - "https and ftp.") + "https and ftp, or jars required to be in the local YARN client's classpath. 
Wildcard " + + "'*' is denoted to download resources for all the schemes.") .stringConf .toSequence .createWithDefault(Nil) @@ -543,4 +564,55 @@ package object config { .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1h") + private[spark] val SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS = + ConfigBuilder("spark.shuffle.minNumPartitionsToHighlyCompress") + .internal() + .doc("Number of partitions to determine if MapStatus should use HighlyCompressedMapStatus") + .intConf + .checkValue(v => v > 0, "The value should be a positive integer.") + .createWithDefault(2000) + + private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS = + ConfigBuilder("spark.storage.memoryMapLimitForTests") + .internal() + .doc("For testing only, controls the size of chunks when memory mapping a file") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Int.MaxValue) + + private[spark] val BARRIER_SYNC_TIMEOUT = + ConfigBuilder("spark.barrier.sync.timeout") + .doc("The timeout in seconds for each barrier() call from a barrier task. If the " + + "coordinator didn't receive all the sync messages from barrier tasks within the " + + "configed time, throw a SparkException to fail all the tasks. The default value is set " + + "to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.") + .timeConf(TimeUnit.SECONDS) + .checkValue(v => v > 0, "The value should be a positive time value.") + .createWithDefaultString("365d") + + private[spark] val BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL = + ConfigBuilder("spark.scheduler.barrier.maxConcurrentTasksCheck.interval") + .doc("Time in seconds to wait between a max concurrent tasks check failure and the next " + + "check. A max concurrent tasks check ensures the cluster can launch more concurrent " + + "tasks than required by a barrier stage on job submitted. The check can fail in case " + + "a cluster has just started and not enough executors have registered, so we wait for a " + + "little while and try to perform the check again. If the check fails more than a " + + "configured max failure times for a job then fail current job submission. Note this " + + "config only applies to jobs that contain one or more barrier stages, we won't perform " + + "the check on non-barrier jobs.") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("15s") + + private[spark] val BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES = + ConfigBuilder("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures") + .doc("Number of max concurrent tasks check failures allowed before fail a job submission. " + + "A max concurrent tasks check ensures the cluster can launch more concurrent tasks than " + + "required by a barrier stage on job submitted. The check can fail in case a cluster " + + "has just started and not enough executors have registered, so we wait for a little " + + "while and try to perform the check again. If the check fails more than a configured " + + "max failure times for a job then fail current job submission. 
Note this config only " + + "applies to jobs that contain one or more barrier stages, we won't perform the check on " + + "non-barrier jobs.") + .intConf + .checkValue(v => v > 0, "The max failures should be a positive value.") + .createWithDefault(40) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala index ddbd624b380d4..af0aa41518766 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala @@ -31,6 +31,8 @@ class HadoopMapRedCommitProtocol(jobId: String, path: String) override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = { val config = context.getConfiguration.asInstanceOf[JobConf] - config.getOutputCommitter + val committer = config.getOutputCommitter + logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}") + committer } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index abf39213fa0d2..9ebd0aa301592 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -76,13 +76,17 @@ object SparkHadoopWriter extends Logging { // Try to write all RDD partitions as a Hadoop OutputFormat. try { val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + // SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers. + // Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently. 
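// Editor's worked example (not part of the patch): the packing below keeps the stage
// attempt in the upper bits and the task attempt in the lower 16 bits, e.g.
//   stageAttemptNumber = 2, attemptNumber = 5  =>  (2 << 16) | 5 = 131077
// which stays unique per (stage attempt, task attempt) pair as long as attemptNumber
// fits in 16 bits -- the Short.MaxValue assumption stated above.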
+ val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber + executeTask( context = context, config = config, jobTrackerId = jobTrackerId, commitJobId = commitJobId, sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, + sparkAttemptNumber = attemptId, committer = committer, iterator = iter) }) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 764735dc4eae7..db8aff94ea1e1 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -69,9 +69,9 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val taskAttemptNumber = TaskContext.get().attemptNumber() - val stageId = TaskContext.get().stageId() - val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber) + val ctx = TaskContext.get() + val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(), + splitId, ctx.attemptNumber()) if (canCommit) { performCommit() @@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber) + throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber()) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index b3f8bfe8b1d48..e94a01244474c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.network import scala.reflect.ClassTag import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] @@ -43,6 +44,17 @@ trait BlockDataManager { level: StorageLevel, classTag: ClassTag[_]): Boolean + /** + * Put the given block that will be received as a stream. + * + * When this method is called, the block data itself is not available -- it will be passed to the + * returned StreamCallbackWithID. + */ + def putBlockDataAsStream( + blockId: BlockId, + level: StorageLevel, + classTag: ClassTag[_]): StreamCallbackWithID + /** * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. 
*/ diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index eb4cf94164fd4..7076701421e2e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -26,9 +26,9 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.NioManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} -import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} +import org.apache.spark.network.shuffle.protocol._ import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, StorageLevel} @@ -73,10 +73,32 @@ class NettyBlockRpcServer( } val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) val blockId = BlockId(uploadBlock.blockId) + logDebug(s"Receiving replicated block $blockId with level ${level} " + + s"from ${client.getSocketAddress}") blockManager.putBlockData(blockId, data, level, classTag) responseContext.onSuccess(ByteBuffer.allocate(0)) } } + override def receiveStream( + client: TransportClient, + messageHeader: ByteBuffer, + responseContext: RpcResponseCallback): StreamCallbackWithID = { + val message = + BlockTransferMessage.Decoder.fromByteBuffer(messageHeader).asInstanceOf[UploadBlockStream] + val (level: StorageLevel, classTag: ClassTag[_]) = { + serializer + .newInstance() + .deserialize(ByteBuffer.wrap(message.metadata)) + .asInstanceOf[(StorageLevel, ClassTag[_])] + } + val blockId = BlockId(message.blockId) + logDebug(s"Receiving replicated block $blockId with level ${level} as stream " + + s"from ${client.getSocketAddress}") + // This will return immediately, but will setup a callback on streamData which will still + // do all the processing in the netty thread. 
+ blockManager.putBlockDataAsStream(blockId, level, classTag) + } + override def getStreamManager(): StreamManager = streamManager } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b7d8c35032763..1905632a936d3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -27,13 +27,14 @@ import scala.reflect.ClassTag import com.codahale.metrics.{Metric, MetricSet} import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.internal.config import org.apache.spark.network._ -import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager} -import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream} import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} @@ -148,20 +149,28 @@ private[spark] class NettyBlockTransferService( // Everything else is encoded using our binary protocol. val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag))) - // Convert or copy nio buffer into array in order to serialize it. - val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) + val asStream = blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + val callback = new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + logTrace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}") + result.success((): Unit) + } - client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer, - new RpcResponseCallback { - override def onSuccess(response: ByteBuffer): Unit = { - logTrace(s"Successfully uploaded block $blockId") - result.success((): Unit) - } - override def onFailure(e: Throwable): Unit = { - logError(s"Error while uploading block $blockId", e) - result.failure(e) - } - }) + override def onFailure(e: Throwable): Unit = { + logError(s"Error while uploading $blockId${if (asStream) " as stream" else ""}", e) + result.failure(e) + } + } + if (asStream) { + val streamHeader = new UploadBlockStream(blockId.name, metadata).toByteBuffer + client.uploadStream(new NioManagedBuffer(streamHeader), blockData, callback) + } else { + // Convert or copy nio buffer into array in order to serialize it. 
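// Editor's note (illustration, not part of the patch): this else branch is the original
// single-RPC upload path, taken when blockData.size() is at or below
// config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM; larger replicated blocks now go through the
// client.uploadStream(...) branch above instead of being copied into one byte array,
// which pairs with that config's default being lowered to Int.MaxValue - 512 earlier in
// this diff.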
+ val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) + + client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer, + callback) + } result.future } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 13db4985b0b80..ba9dae4ad48ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -95,7 +95,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 4e036c2ed49b5..23cf19d55b4ae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -30,7 +30,7 @@ private[spark] class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { - @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) + @transient lazy val _locations = BlockManager.blockIdsToLocations(blockIds, SparkEnv.get) @volatile private var _isValid = true override def getPartitions: Array[Partition] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 44895abc7bd4d..3974580cfaa11 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -278,7 +278,7 @@ class HadoopRDD[K, V]( null } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener { context => + context.addTaskCompletionListener[Unit] { context => // Update the bytes read before closing is to make sure lingering bytesRead statistics in // this thread get correctly added. updateBytesRead() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index aab46b8954bf7..56ef3e107a980 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -77,7 +77,7 @@ class JdbcRDD[T: ClassTag]( override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] { - context.addTaskCompletionListener{ context => closeIfNeeded() } + context.addTaskCompletionListener[Unit]{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index e4587c96eae1c..904d9c025629f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -23,11 +23,21 @@ import org.apache.spark.{Partition, TaskContext} /** * An RDD that applies the provided function to every partition of the parent RDD. + * + * @param prev the parent RDD. 
+ * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to + * an output iterator. + * @param preservesPartitioning Whether the input function preserves the partitioner, which should + * be `false` unless `prev` is a pair RDD and the input function + * doesn't modify the keys. + * @param isFromBarrier Indicates whether this RDD is transformed from an RDDBarrier, a stage + * containing at least one RDDBarrier shall be turned into a barrier stage. */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( var prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) - preservesPartitioning: Boolean = false) + preservesPartitioning: Boolean = false, + isFromBarrier: Boolean = false) extends RDD[U](prev) { override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None @@ -41,4 +51,7 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( super.clearDependencies() prev = null } + + @transient protected lazy override val isBarrier_ : Boolean = + isFromBarrier || dependencies.exists(_.rdd.isBarrier()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index ff66a04859d10..2d66d25ba39fa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -214,7 +214,7 @@ class NewHadoopRDD[K, V]( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener { context => + context.addTaskCompletionListener[Unit] { context => // Update the bytesRead before closing is to make sure lingering bytesRead statistics in // this thread get correctly added. updateBytesRead() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0574abdca32ac..374b846d2ea57 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.partial.BoundedDouble @@ -1647,6 +1647,14 @@ abstract class RDD[T: ClassTag]( } } + /** + * :: Experimental :: + * Indicates that Spark must launch the tasks together for the current stage. + */ + @Experimental + @Since("2.4.0") + def barrier(): RDDBarrier[T] = withScope(new RDDBarrier[T](this)) + // ======================================================================= // Other internal methods and fields // ======================================================================= @@ -1839,6 +1847,24 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } + + /** + * Whether the RDD is in a barrier stage. Spark must launch all the tasks at the same time for a + * barrier stage. + * + * An RDD is in a barrier stage, if at least one of its parent RDD(s), or itself, are mapped from + * an [[RDDBarrier]]. This function always returns false for a [[ShuffledRDD]], since a + * [[ShuffledRDD]] indicates start of a new stage. 
+ * + * A [[MapPartitionsRDD]] can be transformed from an [[RDDBarrier]], under that case the + * [[MapPartitionsRDD]] shall be marked as barrier. + */ + private[spark] def isBarrier(): Boolean = isBarrier_ + + // From performance concern, cache the value to avoid repeatedly compute `isBarrier()` on a long + // RDD chain. + @transient protected lazy val isBarrier_ : Boolean = + dependencies.filter(!_.isInstanceOf[ShuffleDependency[_, _, _]]).exists(_.rdd.isBarrier()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala new file mode 100644 index 0000000000000..b399bf9febae3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.TaskContext +import org.apache.spark.annotation.{Experimental, Since} + +/** Represents an RDD barrier, which forces Spark to launch tasks of this stage together. */ +class RDDBarrier[T: ClassTag](rdd: RDD[T]) { + + /** + * :: Experimental :: + * Generate a new barrier RDD by applying a function to each partitions of the prev RDD. + * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys. + */ + @Experimental + @Since("2.4.0") + def mapPartitions[S: ClassTag]( + f: Iterator[T] => Iterator[S], + preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope { + val cleanedF = rdd.sparkContext.clean(f) + new MapPartitionsRDD( + rdd, + (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter), + preservesPartitioning, + isFromBarrier = true + ) + } + + /** TODO extra conf(e.g. timeout) */ +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 979152b55f957..8273d8a9eb476 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -300,7 +300,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { val deserializeStream = serializer.deserializeStream(fileInputStream) // Register an on-task-completion callback to close the input stream. 
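A hedged usage sketch of the barrier API introduced above (RDD.barrier() returning an RDDBarrier, whose mapPartitions is shown in the new file); the SparkContext sc, the data and the partition count are illustrative:

  val rdd = sc.parallelize(1 to 100, numSlices = 4)
  val doubled = rdd.barrier().mapPartitions { iter =>
    // the 4 tasks of this stage are launched together or not at all; a real barrier job
    // would usually coordinate between tasks here, this sketch only maps the values
    iter.map(_ * 2)
  }
  doubled.collect()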
- context.addTaskCompletionListener(context => deserializeStream.close()) + context.addTaskCompletionListener[Unit](context => deserializeStream.close()) deserializeStream.asIterator.asInstanceOf[Iterator[T]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 26eaa9aa3d03f..e8f9b27b7eb55 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -110,4 +110,6 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( super.clearDependencies() prev = null } + + private[spark] override def isBarrier(): Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 60e383afadf1c..4b6f73235a57a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -20,12 +20,13 @@ package org.apache.spark.rdd import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.collection.parallel.ForkJoinTaskSupport +import scala.concurrent.ExecutionContext import scala.concurrent.forkjoin.ForkJoinPool import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.ThreadUtils.parmap import org.apache.spark.util.Utils /** @@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag]( } object UnionRDD { - private[spark] lazy val partitionEvalTaskSupport = - new ForkJoinTaskSupport(new ForkJoinPool(8)) + private[spark] lazy val threadPool = new ForkJoinPool(8) } @DeveloperApi @@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag]( rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10) override def getPartitions: Array[Partition] = { - val parRDDs = if (isPartitionListingParallel) { - val parArray = rdds.par - parArray.tasksupport = UnionRDD.partitionEvalTaskSupport - parArray + val partitionLengths = if (isPartitionListingParallel) { + implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool) + parmap(rdds)(_.partitions.length) } else { - rdds + rdds.map(_.partitions.length) } - val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum) + val array = new Array[Partition](partitionLengths.sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index a2936d6ad539c..47576959322d1 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -50,7 +50,7 @@ private[netty] class NettyRpcEnv( private[netty] val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", - conf.getInt("spark.rpc.io.threads", 0)) + conf.getInt("spark.rpc.io.threads", numUsableCores)) private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 949e88f606275..6e4d062749d5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -60,4 +60,10 @@ private[spark] class ActiveJob( val finished = Array.fill[Boolean](numPartitions)(false) var numFinished = 0 + + /** Resets the status of all partitions in this stage so they are marked as not finished. */ + def resetAllPartitions(): Unit = { + (0 until numPartitions).foreach(finished.update(_, false)) + numFinished = 0 + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index c1fedd63f6a90..e2b6df4600590 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -34,7 +34,11 @@ import org.apache.spark.util.Utils * Delivery will only begin when the `start()` method is called. The `stop()` method should be * called when no more events need to be delivered. */ -private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) +private class AsyncEventQueue( + val name: String, + conf: SparkConf, + metrics: LiveListenerBusMetrics, + bus: LiveListenerBus) extends SparkListenerBus with Logging { @@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi } private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { - try { - var next: SparkListenerEvent = eventQueue.take() - while (next != POISON_PILL) { - val ctx = processingTime.time() - try { - super.postToAll(next) - } finally { - ctx.stop() - } - eventCount.decrementAndGet() - next = eventQueue.take() + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() } eventCount.decrementAndGet() - } catch { - case ie: InterruptedException => - logInfo(s"Stopping listener queue $name.", ie) + next = eventQueue.take() } + eventCount.decrementAndGet() } override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { @@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi eventCount.incrementAndGet() eventQueue.put(POISON_PILL) } - dispatchThread.join() + // this thread might be trying to stop itself as part of error handling -- we can't join + // in that case. + if (Thread.currentThread() != dispatchThread) { + dispatchThread.join() + } } def post(event: SparkListenerEvent): Unit = { @@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi true } + override def removeListenerOnError(listener: SparkListenerInterface): Unit = { + // the listener failed in an unrecoverably way, we want to remove it from the entire + // LiveListenerBus (potentially stopping a queue if it is empty) + bus.removeListener(listener) + } + } private object AsyncEventQueue { diff --git a/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala b/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala new file mode 100644 index 0000000000000..803a0a1226d6c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.SparkException + +/** + * Exception thrown when submit a job with barrier stage(s) failing a required check. + */ +private[spark] class BarrierJobAllocationFailed(message: String) extends SparkException(message) + +private[spark] class BarrierJobUnsupportedRDDChainException + extends BarrierJobAllocationFailed( + BarrierJobAllocationFailed.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + +private[spark] class BarrierJobRunWithDynamicAllocationException + extends BarrierJobAllocationFailed( + BarrierJobAllocationFailed.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + +private[spark] class BarrierJobSlotsNumberCheckFailed + extends BarrierJobAllocationFailed( + BarrierJobAllocationFailed.ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + +private[spark] object BarrierJobAllocationFailed { + + // Error message when running a barrier stage that have unsupported RDD chain pattern. + val ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN = + "[SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following pattern of " + + "RDD chain within a barrier stage:\n1. Ancestor RDDs that have different number of " + + "partitions from the resulting RDD (eg. union()/coalesce()/first()/take()/" + + "PartitionPruningRDD). A workaround for first()/take() can be barrierRdd.collect().head " + + "(scala) or barrierRdd.collect()[0] (python).\n" + + "2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2))." + + // Error message when running a barrier stage with dynamic resource allocation enabled. + val ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION = + "[SPARK-24942]: Barrier execution mode does not support dynamic resource allocation for " + + "now. You can disable dynamic resource allocation by setting Spark conf " + + "\"spark.dynamicAllocation.enabled\" to \"false\"." + + // Error message when running a barrier stage that requires more slots than current total number. + val ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER = + "[SPARK-24819]: Barrier execution mode does not allow run a barrier stage that requires " + + "more slots than the total number of slots in the cluster currently. Please init a new " + + "cluster with more CPU cores or repartition the input RDD(s) to reduce the number of " + + "slots required to run this barrier stage." 
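A hedged sketch of the situation the last message above describes (the local[4] master and the sizes are editor's illustrations, not from the patch): a barrier stage needs one slot per task, so submitting more barrier tasks than the cluster currently has slots keeps failing the scheduler's check until the job itself is failed with this message.

  // assuming a local[4] master, i.e. 4 cores with 1 CPU per task = 4 slots
  val b = sc.parallelize(1 to 80, numSlices = 8).barrier().mapPartitions(iter => iter)
  b.collect()   // 8 barrier tasks > 4 slots: the slots-number check fails, is retried on
                // the configured interval, and the job fails after the configured number
                // of failures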
+} diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 30cf75d43ee09..980fbbe516b91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -371,7 +371,7 @@ private[scheduler] class BlacklistTracker ( } -private[scheduler] object BlacklistTracker extends Logging { +private[spark] object BlacklistTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 78b6b34b5d2bb..6787250ddc3f4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,8 +19,9 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicInteger +import java.util.function.BiFunction import scala.annotation.tailrec import scala.collection.Map @@ -111,8 +112,7 @@ import org.apache.spark.util._ * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to * include the new structure. This will help to catch memory leaks. */ -private[spark] -class DAGScheduler( +private[spark] class DAGScheduler( private[scheduler] val sc: SparkContext, private[scheduler] val taskScheduler: TaskScheduler, listenerBus: LiveListenerBus, @@ -203,10 +203,28 @@ class DAGScheduler( sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) + /** + * Number of max concurrent tasks check failures for each barrier job. + */ + private[scheduler] val barrierJobIdToNumTasksCheckFailures = new ConcurrentHashMap[Int, Int] + + /** + * Time in seconds to wait between a max concurrent tasks check failure and the next check. + */ + private val timeIntervalNumTasksCheck = sc.getConf + .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL) + + /** + * Max number of max concurrent tasks check failures allowed for a job before fail the job + * submission. + */ + private val maxFailureNumTasksCheck = sc.getConf + .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES) + private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") - private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) /** @@ -340,6 +358,21 @@ class DAGScheduler( } } + /** + * Check to make sure we don't launch a barrier stage with unsupported RDD chain pattern. The + * following patterns are not supported: + * 1. Ancestor RDDs that have different number of partitions from the resulting RDD (eg. + * union()/coalesce()/first()/take()/PartitionPruningRDD); + * 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)). 
+ */ + private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numTasksInStage: Int): Unit = { + val predicate: RDD[_] => Boolean = (r => + r.getNumPartitions == numTasksInStage && r.dependencies.filter(_.rdd.isBarrier()).size <= 1) + if (rdd.isBarrier() && !traverseParentRDDsWithinStage(rdd, predicate)) { + throw new BarrierJobUnsupportedRDDChainException + } + } + /** * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a * previously run stage generated the same shuffle data, this function will copy the output @@ -348,6 +381,9 @@ class DAGScheduler( */ def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd + checkBarrierStageWithDynamicAllocation(rdd) + checkBarrierStageWithNumSlots(rdd) + checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions) val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() @@ -367,6 +403,36 @@ class DAGScheduler( stage } + /** + * We don't support run a barrier stage with dynamic resource allocation enabled, it shall lead + * to some confusing behaviors (eg. with dynamic resource allocation enabled, it may happen that + * we acquire some executors (but not enough to launch all the tasks in a barrier stage) and + * later release them due to executor idle time expire, and then acquire again). + * + * We perform the check on job submit and fail fast if running a barrier stage with dynamic + * resource allocation enabled. + * + * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage + */ + private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = { + if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) { + throw new BarrierJobRunWithDynamicAllocationException + } + } + + /** + * Check whether the barrier stage requires more slots (to be able to launch all tasks in the + * barrier stage together) than the total number of active slots currently. Fail current check + * if trying to submit a barrier stage that requires more slots than current total number. If + * the check fails consecutively beyond a configured number for a job, then fail current job + * submission. + */ + private def checkBarrierStageWithNumSlots(rdd: RDD[_]): Unit = { + if (rdd.isBarrier() && rdd.getNumPartitions > sc.maxNumConcurrentTasks) { + throw new BarrierJobSlotsNumberCheckFailed + } + } + /** * Create a ResultStage associated with the provided jobId. */ @@ -376,6 +442,9 @@ class DAGScheduler( partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { + checkBarrierStageWithDynamicAllocation(rdd) + checkBarrierStageWithNumSlots(rdd) + checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size) val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite) @@ -451,6 +520,32 @@ class DAGScheduler( parents } + /** + * Traverses the given RDD and its ancestors within the same stage and checks whether all of the + * RDDs satisfy a given predicate. 
+ */ + private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = { + val visited = new HashSet[RDD[_]] + val waitingForVisit = new ArrayStack[RDD[_]] + waitingForVisit.push(rdd) + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.pop() + if (!visited(toVisit)) { + if (!predicate(toVisit)) { + return false + } + visited += toVisit + toVisit.dependencies.foreach { + case _: ShuffleDependency[_, _, _] => + // Not within the same stage with current rdd, do nothing. + case dependency => + waitingForVisit.push(dependency.rdd) + } + } + } + true + } + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] @@ -866,11 +961,38 @@ class DAGScheduler( // HadoopRDD whose underlying HDFS files have been deleted. finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite) } catch { + case e: BarrierJobSlotsNumberCheckFailed => + logWarning(s"The job $jobId requires to run a barrier stage that requires more slots " + + "than the total number of slots in the cluster currently.") + // If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically. + val numCheckFailures = barrierJobIdToNumTasksCheckFailures.compute(jobId, + new BiFunction[Int, Int, Int] { + override def apply(key: Int, value: Int): Int = value + 1 + }) + if (numCheckFailures <= maxFailureNumTasksCheck) { + messageScheduler.schedule( + new Runnable { + override def run(): Unit = eventProcessLoop.post(JobSubmitted(jobId, finalRDD, func, + partitions, callSite, listener, properties)) + }, + timeIntervalNumTasksCheck, + TimeUnit.SECONDS + ) + return + } else { + // Job failed, clear internal data. + barrierJobIdToNumTasksCheckFailures.remove(jobId) + listener.jobFailed(e) + return + } + case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return } + // Job submitted, clear internal data. 
+ barrierJobIdToNumTasksCheckFailures.remove(jobId) val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) clearCacheLocs() @@ -1062,7 +1184,7 @@ class DAGScheduler( stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), - Option(sc.applicationId), sc.applicationAttemptId) + Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier()) } case stage: ResultStage => @@ -1072,7 +1194,8 @@ class DAGScheduler( val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, - Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) + Option(jobId), Option(sc.applicationId), sc.applicationAttemptId, + stage.rdd.isBarrier()) } } } catch { @@ -1167,12 +1290,11 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task - val taskId = event.taskInfo.id val stageId = task.stageId - val taskType = Utils.getFormattedClassName(task) outputCommitCoordinator.taskCompleted( stageId, + task.stageAttemptId, task.partitionId, event.taskInfo.attemptNumber, // this is a task attempt number event.reason) @@ -1210,7 +1332,7 @@ class DAGScheduler( case _ => updateAccumulators(event) } - case _: ExceptionFailure => updateAccumulators(event) + case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event) case _ => } postTaskEnd(event) @@ -1312,18 +1434,7 @@ class DAGScheduler( } } - case Resubmitted => - logInfo("Resubmitted " + task + ", so marking it as still running") - stage match { - case sms: ShuffleMapStage => - sms.pendingPartitions += task.partitionId - - case _ => - assert(false, "TaskSetManagers should only send Resubmitted task statuses for " + - "tasks in ShuffleMapStages.") - } - - case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => + case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1332,22 +1443,48 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) + markStageAsFinished(failedStage, errorMessage = Some(failureMessage), + willRetry = !shouldAbortStage) } else { logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + s"longer running") } - failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest + if (mapStage.rdd.isBarrier()) { + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. 
+ mapOutputTracker.unregisterAllMapOutput(shuffleId) + } else if (mapId != -1) { + // Mark the map whose fetch failed as broken in the map stage + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + + if (failedStage.rdd.isBarrier()) { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Abort the failed result stage since we may have committed output for some + // partitions. + val reason = "Could not recover from a failed barrier ResultStage. Most recent " + + s"failure reason: $failureMessage" + abortStage(failedResultStage, reason, None) + } + } if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { @@ -1375,7 +1512,7 @@ class DAGScheduler( // simpler while not producing an overwhelming number of scheduler events. logInfo( s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure" + s"$failedStage (${failedStage.name}) due to fetch failure" ) messageScheduler.schedule( new Runnable { @@ -1386,10 +1523,6 @@ class DAGScheduler( ) } } - // Mark the map whose fetch failed as broken in the map stage - if (mapId != -1) { - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { @@ -1411,21 +1544,118 @@ class DAGScheduler( } } - case commitDenied: TaskCommitDenied => + case failure: TaskFailedReason if task.isBarrier => + // Also handle the task failed reasons here. + failure match { + case Resubmitted => + handleResubmittedFailure(task, stage) + + case _ => // Do nothing. + } + + // Always fail the current stage and retry all the tasks when a barrier task fail. + val failedStage = stageIdToStage(task.stageId) + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring task failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") + } else { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + + "failed.") + val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" + + failure.toErrorString + try { + // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask. + val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) " + + "failed." + taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason) + } catch { + case e: UnsupportedOperationException => + // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. + // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. + logWarning(s"Could not kill all tasks for stage $stageId", e) + abortStage(failedStage, "Could not kill zombie barrier tasks for stage " + + s"$failedStage (${failedStage.name})", Some(e)) + } + markStageAsFinished(failedStage, Some(message)) + + failedStage.failedAttemptIds.add(task.stageAttemptId) + // TODO Refactor the failure handling logic to combine similar code with that of + // FetchFailed. 
+ val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Barrier stage will not retry stage due to testing config. Most recent failure " + + s"reason: $message" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: $maxConsecutiveStageAttempts. + |Most recent failure reason: $message + """.stripMargin.replaceAll("\n", " ") + } + abortStage(failedStage, abortMessage, None) + } else { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Abort the failed result stage since we may have committed output for some + // partitions. + val reason = "Could not recover from a failed barrier ResultStage. Most recent " + + s"failure reason: $message" + abortStage(failedResultStage, reason, None) + } + // In case multiple task failures triggered for a single stage attempt, ensure we only + // resubmit the failed stage once. + val noResubmitEnqueued = !failedStages.contains(failedStage) + failedStages += failedStage + if (noResubmitEnqueued) { + logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " + + "failure.") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + } + } + + case Resubmitted => + handleResubmittedFailure(task, stage) + + case _: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case exceptionFailure: ExceptionFailure => + case _: ExceptionFailure | _: TaskKilled => // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => + case _: ExecutorLostFailure | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } } + private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = { + logInfo(s"Resubmitted $task, so marking it as still running.") + stage match { + case sms: ShuffleMapStage => + sms.pendingPartitions += task.partitionId + + case _ => + throw new SparkException("TaskSetManagers should only send Resubmitted task " + + "statuses for tasks in ShuffleMapStages.") + } + } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { // Mark any map-stage jobs waiting on this stage as finished if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { @@ -1547,7 +1777,10 @@ class DAGScheduler( /** * Marks a stage as finished and removes it from the list of running stages. 
*/ - private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + private def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" @@ -1566,7 +1799,9 @@ class DAGScheduler( logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") } - outputCommitCoordinator.stageEnd(stage.id) + if (!willRetry) { + outputCommitCoordinator.stageEnd(stage.id) + } listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index ba6387a8f08ad..d135190d1e919 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) { queue.addListener(listener) case None => - val newQueue = new AsyncEventQueue(queue, conf, metrics) + val newQueue = new AsyncEventQueue(queue, conf, metrics, this) newQueue.addListener(listener) if (started.get()) { newQueue.start(sparkContext) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 2ec2f2031aa45..7e1d75fe723d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -31,7 +31,8 @@ import org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. + * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task, + * for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { /** Location where this task was run. */ @@ -44,16 +45,23 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long + + /** + * The number of outputs for the map task. 
+ */ + def numberOfOutput: Long } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > 2000) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = { + if (uncompressedSizes.length > Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) + .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { + HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, numOutput) } } @@ -96,29 +104,34 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) } override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -141,17 +154,20 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte]) + private var hugeBlockSizes: Map[Int, Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -166,6 +182,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -177,6 +194,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: 
ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -192,7 +210,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + numOutput: Long): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -233,6 +254,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap) + hugeBlockSizesArray.toMap, numOutput) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 83d87b548a430..b382d623806e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) +private case class AskPermissionToCommitOutput( + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None - private type StageId = Int - private type PartitionId = Int - private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + // Class used to identify a committer. The task ID for a committer is implicitly defined by + // the partition being processed, but the coordinator needs to keep track of both the stage + // attempt and the task attempt, because in some situations the same task may be running + // concurrently in two different attempts of the same stage. + private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int) + private case class StageState(numPartitions: Int) { - val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) - val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null) + val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]() } /** @@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val stageStates = mutable.Map[StageId, StageState]() + private val stageStates = mutable.Map[Int, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. 
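The core of the OutputCommitCoordinator change is that a committer is now identified by the pair of stage attempt and task attempt rather than the task attempt number alone, so concurrent attempts of the same stage can no longer be confused on a partition. Below is a minimal, self-contained sketch of that "first committer wins" bookkeeping; the RPC endpoint and DAGScheduler wiring are omitted, and the names are illustrative stand-ins, not the patched class itself.

import scala.collection.mutable

object CommitProtocolSketch {
  case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

  class StageState(numPartitions: Int) {
    val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null)
    val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]()
  }

  private val stageStates = mutable.Map[Int, StageState]()

  // A stage retry reuses the existing state, so an authorized committer from an earlier
  // attempt keeps its lock across the retry.
  def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized {
    stageStates.getOrElseUpdate(stage, new StageState(maxPartitionId + 1))
  }

  def canCommit(stage: Int, stageAttempt: Int, partition: Int, attemptNumber: Int): Boolean =
    synchronized {
      val attempt = TaskIdentifier(stageAttempt, attemptNumber)
      stageStates.get(stage) match {
        case Some(state) if state.failures.get(partition).exists(_.contains(attempt)) =>
          false // this exact attempt already failed, never let it commit
        case Some(state) if state.authorizedCommitters(partition) == null =>
          state.authorizedCommitters(partition) = attempt // first committer wins
          true
        case _ =>
          false // another attempt holds the lock, or the stage already finished
      }
    }

  def taskFailed(stage: Int, stageAttempt: Int, partition: Int, attemptNumber: Int): Unit =
    synchronized {
      stageStates.get(stage).foreach { state =>
        val attempt = TaskIdentifier(stageAttempt, attemptNumber)
        state.failures.getOrElseUpdate(partition, mutable.Set()) += attempt
        if (state.authorizedCommitters(partition) == attempt) {
          state.authorizedCommitters(partition) = null // free the slot for another attempt
        }
      }
    }
}

In this model, as in the patch, a failure of the authorized committer clears the slot so a later attempt can take over, while every failed attempt stays blacklisted for that partition.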
@@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @return true if this task is authorized to commit, false otherwise */ def canCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = { + val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg), @@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } /** - * Called by the DAGScheduler when a stage starts. + * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't + * yet been initialized. * * @param stage the stage id. * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { - stageStates(stage) = new StageState(maxPartitionId + 1) + private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized { + stageStates.get(stage) match { + case Some(state) => + require(state.authorizedCommitters.length == maxPartitionId + 1) + logInfo(s"Reusing state from previous attempt of stage $stage.") + + case _ => + stageStates(stage) = new StageState(maxPartitionId + 1) + } } // Called by DAGScheduler - private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + private[scheduler] def stageEnd(stage: Int): Unit = synchronized { stageStates.remove(stage) } // Called by DAGScheduler private[scheduler] def taskCompleted( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber, + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int, reason: TaskEndReason): Unit = synchronized { val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) reason match { case Success => // The task output has been committed successfully - case denied: TaskCommitDenied => - logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + - s"attempt: $attemptNumber") - case otherReason => + case _: TaskCommitDenied => + logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " + + s"partition: $partition, attempt: $attemptNumber") + case _ => // Mark the attempt as failed to blacklist from future commit protocol - stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber - if (stageState.authorizedCommitters(partition) == attemptNumber) { + val taskId = TaskIdentifier(stageAttempt, attemptNumber) + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId + if (stageState.authorizedCommitters(partition) == taskId) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = null } } } @@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Marked private[scheduler] instead of private so this can be mocked in 
tests private[scheduler] def handleAskPermissionToCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = synchronized { + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = synchronized { stageStates.get(stage) match { - case Some(state) if attemptFailed(state, partition, attemptNumber) => - logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition as task attempt $attemptNumber has already failed.") + case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) => + logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"task attempt $attemptNumber already marked as failed.") false case Some(state) => - state.authorizedCommitters(partition) match { - case NO_AUTHORIZED_COMMITTER => - logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition") - state.authorizedCommitters(partition) = attemptNumber - true - case existingCommitter => - // Coordinator should be idempotent when receiving AskPermissionToCommit. - if (existingCommitter == attemptNumber) { - logWarning(s"Authorizing duplicate request to commit for " + - s"attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition; existingCommitter = $existingCommitter." + - s" This can indicate dropped network traffic.") - true - } else { - logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition; existingCommitter = $existingCommitter") - false - } + val existing = state.authorizedCommitters(partition) + if (existing == null) { + logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " + + s"task attempt $attemptNumber") + state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber) + true + } else { + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"already committed by $existing") + false } case None => - logDebug(s"Stage $stage has completed, so not allowing" + - s" attempt number $attemptNumber of partition $partition to commit") + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + "stage already marked as completed.") false } } private def attemptFailed( stageState: StageState, - partition: PartitionId, - attempt: TaskAttemptNumber): Boolean = synchronized { - stageState.failures.get(partition).exists(_.contains(attempt)) + stageAttempt: Int, + partition: Int, + attempt: Int): Boolean = synchronized { + val failInfo = TaskIdentifier(stageAttempt, attempt) + stageState.failures.get(partition).exists(_.contains(failInfo)) } } @@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, attemptNumber) => + case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition, + attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index e36c759a42556..aafeae05b566c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -48,7 +48,9 @@ import org.apache.spark.rdd.RDD * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to - */ + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. + */ private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, @@ -60,9 +62,10 @@ private[spark] class ResultTask[T, U]( serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics, - jobId, appId, appAttemptId) + jobId, appId, appAttemptId, isBarrier) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 22db3350abfa7..c187ee146301b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -69,4 +69,13 @@ private[spark] trait SchedulerBackend { */ def getDriverLogUrls: Option[Map[String, String]] = None + /** + * Get the max number of tasks that can be concurrent launched currently. + * Note that please don't cache the value returned by this method, because the number can change + * due to add/remove executors. + * + * @return The max number of tasks that can be concurrent launched currently. + */ + def maxNumConcurrentTasks(): Int + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 7a25c47e2cab3..f2cd65fd523ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -49,6 +49,8 @@ import org.apache.spark.shuffle.ShuffleWriter * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -60,9 +62,10 @@ private[spark] class ShuffleMapTask( serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties, - serializedTaskMetrics, jobId, appId, appAttemptId) + serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 290fd073caf27..26cca334d3bd5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -82,15 +82,15 @@ private[scheduler] abstract class Stage( private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) /** - * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these - * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure. + * Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid + * endless retries if a stage keeps failing. * We keep track of each attempt ID that has failed to avoid recording duplicate failures if * multiple tasks from the same stage attempt fail (SPARK-5945). */ - val fetchFailedAttemptIds = new HashSet[Int] + val failedAttemptIds = new HashSet[Int] private[scheduler] def clearFailures() : Unit = { - fetchFailedAttemptIds.clear() + failedAttemptIds.clear() } /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index f536fc2a5f0a1..11f85fd91ba08 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -49,6 +49,8 @@ import org.apache.spark.util._ * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. */ private[spark] abstract class Task[T]( val stageId: Int, @@ -60,7 +62,8 @@ private[spark] abstract class Task[T]( SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), val jobId: Option[Int] = None, val appId: Option[String] = None, - val appAttemptId: Option[String] = None) extends Serializable { + val appAttemptId: Option[String] = None, + val isBarrier: Boolean = false) extends Serializable { @transient lazy val metrics: TaskMetrics = SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics)) @@ -77,16 +80,32 @@ private[spark] abstract class Task[T]( attemptNumber: Int, metricsSystem: MetricsSystem): T = { SparkEnv.get.blockManager.registerTask(taskAttemptId) - context = new TaskContextImpl( - stageId, - stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal - partitionId, - taskAttemptId, - attemptNumber, - taskMemoryManager, - localProperties, - metricsSystem, - metrics) + // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether + // the stage is barrier. 
+ context = if (isBarrier) { + new BarrierTaskContext( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + } else { + new TaskContextImpl( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + } + TaskContext.setTaskContext(context) taskThread = Thread.currentThread() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index c98b87148e404..bb4a4442b9433 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -50,6 +50,7 @@ private[spark] class TaskDescription( val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet + val partitionId: Int, val addedFiles: Map[String, Long], val addedJars: Map[String, Long], val properties: Properties, @@ -76,6 +77,7 @@ private[spark] object TaskDescription { dataOut.writeUTF(taskDescription.executorId) dataOut.writeUTF(taskDescription.name) dataOut.writeInt(taskDescription.index) + dataOut.writeInt(taskDescription.partitionId) // Write files. serializeStringLongMap(taskDescription.addedFiles, dataOut) @@ -117,6 +119,7 @@ private[spark] object TaskDescription { val executorId = dataIn.readUTF() val name = dataIn.readUTF() val index = dataIn.readInt() + val partitionId = dataIn.readInt() // Read files. val taskFiles = deserializeStringLongMap(dataIn) @@ -138,7 +141,7 @@ private[spark] object TaskDescription { // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). val serializedTask = byteBuffer.slice() - new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, - properties, serializedTask) + new TaskDescription(taskId, attemptNumber, executorId, name, index, partitionId, taskFiles, + taskJars, properties, serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 90644fea23ab1..95f7ae4fd39a2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -51,16 +51,22 @@ private[spark] trait TaskScheduler { // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit - // Cancel a stage. + // Kill all the tasks in a stage and fail the stage and all the jobs that depend on the stage. + // Throw UnsupportedOperationException if the backend doesn't support kill tasks. def cancelTasks(stageId: Int, interruptThread: Boolean): Unit /** * Kills a task attempt. + * Throw UnsupportedOperationException if the backend doesn't support kill a task. * * @return Whether the task was successfully killed. */ def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean + // Kill all the running task attempts in a stage. + // Throw UnsupportedOperationException if the backend doesn't support kill tasks. + def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. 
def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c11806b3981b..8992d7e2284a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -30,6 +30,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging import org.apache.spark.internal.config +import org.apache.spark.rpc.RpcEndpoint import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.BlockManagerId @@ -42,7 +43,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * up to launch speculative tasks, etc. * * Clients should first call initialize() and start(), then submit task sets through the - * runTasks method. + * submitTasks method. * * THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some @@ -62,7 +63,7 @@ private[spark] class TaskSchedulerImpl( this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // Lazily initializing blacklistTrackerOpt to avoid getting empty ExecutorAllocationClient, // because ExecutorAllocationClient is created after this TaskSchedulerImpl. private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) @@ -138,6 +139,19 @@ private[spark] class TaskSchedulerImpl( // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) + private lazy val barrierSyncTimeout = conf.get(config.BARRIER_SYNC_TIMEOUT) + + private[scheduler] var barrierCoordinator: RpcEndpoint = null + + private def maybeInitBarrierCoordinator(): Unit = { + if (barrierCoordinator == null) { + barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus, + sc.env.rpcEnv) + sc.env.rpcEnv.setupEndpoint("barrierSync", barrierCoordinator) + logInfo("Registered BarrierCoordinator endpoint") + } + } + override def setDAGScheduler(dagScheduler: DAGScheduler) { this.dagScheduler = dagScheduler } @@ -222,18 +236,11 @@ private[spark] class TaskSchedulerImpl( override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) + // Kill all running tasks for the stage. + killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled") + // Cancel all attempts for the stage. taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => attempts.foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. 
- tsm.runningTasksSet.foreach { tid => - taskIdToExecutorId.get(tid).foreach(execId => - backend.killTask(tid, execId, interruptThread, reason = "Stage cancelled")) - } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) } @@ -252,6 +259,27 @@ private[spark] class TaskSchedulerImpl( } } + override def killAllTaskAttempts( + stageId: Int, + interruptThread: Boolean, + reason: String): Unit = synchronized { + logInfo(s"Killing all running tasks in stage $stageId: $reason") + taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => + attempts.foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task. + // 2. The task set manager has been created but no tasks have been scheduled. In this case, + // simply continue. + tsm.runningTasksSet.foreach { tid => + taskIdToExecutorId.get(tid).foreach { execId => + backend.killTask(tid, execId, interruptThread, reason) + } + } + } + } + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -274,7 +302,8 @@ private[spark] class TaskSchedulerImpl( maxLocality: TaskLocality, shuffledOffers: Seq[WorkerOffer], availableCpus: Array[Int], - tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { + tasks: IndexedSeq[ArrayBuffer[TaskDescription]], + addressesWithDescs: ArrayBuffer[(String, TaskDescription)]) : Boolean = { var launchedTask = false // nodes and executors that are blacklisted for the entire application have already been // filtered out by this point @@ -291,6 +320,11 @@ private[spark] class TaskSchedulerImpl( executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) + // Only update hosts for a barrier task. + if (taskSet.isBarrier) { + // The executor address is expected to be non empty. + addressesWithDescs += (shuffledOffers(i).address.get -> task) + } launchedTask = true } } catch { @@ -346,6 +380,7 @@ private[spark] class TaskSchedulerImpl( // Build a list of tasks to assign to each worker. val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableCpus = shuffledOffers.map(o => o.cores).toArray + val availableSlots = shuffledOffers.map(o => o.cores / CPUS_PER_TASK).sum val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( @@ -359,20 +394,58 @@ private[spark] class TaskSchedulerImpl( // of locality levels so that it gets a chance to launch local tasks on all of them. // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY for (taskSet <- sortedTaskSets) { - var launchedAnyTask = false - var launchedTaskAtCurrentMaxLocality = false - for (currentMaxLocality <- taskSet.myLocalityLevels) { - do { - launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet( - taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks) - launchedAnyTask |= launchedTaskAtCurrentMaxLocality - } while (launchedTaskAtCurrentMaxLocality) - } - if (!launchedAnyTask) { - taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + // Skip the barrier taskSet if the available slots are less than the number of pending tasks. 
+ if (taskSet.isBarrier && availableSlots < taskSet.numTasks) { + // Skip the launch process. + // TODO SPARK-24819 If the job requires more slots than available (both busy and free + // slots), fail the job on submit. + logInfo(s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + + s"because the barrier taskSet requires ${taskSet.numTasks} slots, while the total " + + s"number of available slots is $availableSlots.") + } else { + var launchedAnyTask = false + // Record all the executor IDs assigned barrier tasks on. + val addressesWithDescs = ArrayBuffer[(String, TaskDescription)]() + for (currentMaxLocality <- taskSet.myLocalityLevels) { + var launchedTaskAtCurrentMaxLocality = false + do { + launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(taskSet, + currentMaxLocality, shuffledOffers, availableCpus, tasks, addressesWithDescs) + launchedAnyTask |= launchedTaskAtCurrentMaxLocality + } while (launchedTaskAtCurrentMaxLocality) + } + if (!launchedAnyTask) { + taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + } + if (launchedAnyTask && taskSet.isBarrier) { + // Check whether the barrier tasks are partially launched. + // TODO SPARK-24818 handle the assert failure case (that can happen when some locality + // requirements are not fulfilled, and we should revert the launched tasks). + require(addressesWithDescs.size == taskSet.numTasks, + s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + + s"because only ${addressesWithDescs.size} out of a total number of " + + s"${taskSet.numTasks} tasks got resource offers. The resource offers may have " + + "been blacklisted or cannot fulfill task locality requirements.") + + // materialize the barrier coordinator. + maybeInitBarrierCoordinator() + + // Update the taskInfos into all the barrier task properties. + val addressesStr = addressesWithDescs + // Addresses ordered by partitionId + .sortBy(_._2.partitionId) + .map(_._1) + .mkString(",") + addressesWithDescs.foreach(_._2.properties.setProperty("addresses", addressesStr)) + + logInfo(s"Successfully scheduled all the ${addressesWithDescs.size} tasks for barrier " + + s"stage ${taskSet.stageId}.") + } } } + // TODO SPARK-24823 Cancel a job that contains barrier stage(s) if the barrier tasks don't get + // launched within a configured time. if (tasks.size > 0) { hasLaunchedTask = true } @@ -510,6 +583,9 @@ private[spark] class TaskSchedulerImpl( if (taskResultGetter != null) { taskResultGetter.stop() } + if (barrierCoordinator != null) { + barrierCoordinator.stop() + } starvationTimer.cancel() } @@ -689,6 +765,23 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in all TaskSetManagers for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier attempt of a stage completes a task, we should ensure that the later attempts + * do not also submit those same tasks. That also means that a task completion from an earlier + * attempt can lead to the entire stage getting marked as successful. 
+ */ + private[scheduler] def markPartitionCompletedInAllTaskSets( + stageId: Int, + partitionId: Int, + taskInfo: TaskInfo) = { + taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => + tsm.markPartitionCompleted(partitionId, taskInfo) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d958658527f6d..d5e85a11cb279 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.SchedulingMode._ -import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} +import org.apache.spark.util.{AccumulatorV2, Clock, LongAccumulator, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap /** @@ -73,6 +73,8 @@ private[spark] class TaskSetManager( val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks + private[scheduler] val partitionToIndex = tasks.zipWithIndex + .map { case (t, idx) => t.partitionId -> idx }.toMap val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -82,10 +84,10 @@ private[spark] class TaskSetManager( val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) - // Set the coresponding index of Boolean var when the task killed by other attempt tasks, - // this happened while we set the `spark.speculation` to true. The task killed by others + // Add the tid of task into this HashSet when the task is killed by other attempt tasks. + // This happened while we set the `spark.speculation` to true. The task killed by others // should not resubmit while executor lost. - private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks) + private val killedByOtherAttempt = new HashSet[Long] val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 @@ -121,6 +123,10 @@ private[spark] class TaskSetManager( // TODO: We should kill any running task attempts when the task set manager becomes a zombie. private[scheduler] var isZombie = false + // Whether the taskSet run tasks from a barrier stage. Spark must launch all the tasks at the + // same time for a barrier stage. + private[scheduler] def isBarrier = taskSet.tasks.nonEmpty && taskSet.tasks(0).isBarrier + // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect @@ -153,7 +159,7 @@ private[spark] class TaskSetManager( private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - private val taskInfos = new HashMap[Long, TaskInfo] + private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. 
This is only used when speculation is enabled, to avoid the overhead @@ -287,7 +293,7 @@ private[spark] class TaskSetManager( None } - /** Check whether a task is currently running an attempt on a given host */ + /** Check whether a task once ran an attempt on a given host */ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { taskAttempts(taskIndex).exists(_.host == host) } @@ -510,6 +516,7 @@ private[spark] class TaskSetManager( execId, taskName, index, + task.partitionId, addedFiles, addedJars, task.localProperties, @@ -721,6 +728,23 @@ private[spark] class TaskSetManager( def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index + // Check if any other attempt succeeded before this and this attempt has not been handled + if (successful(index) && killedByOtherAttempt.contains(tid)) { + // Undo the effect on calculatedTasks and totalResultSize made earlier when + // checking if can fetch more results + calculatedTasks -= 1 + val resultSizeAcc = result.accumUpdates.find(a => + a.name == Some(InternalAccumulator.RESULT_SIZE)) + if (resultSizeAcc.isDefined) { + totalResultSize -= resultSizeAcc.get.asInstanceOf[LongAccumulator].value + } + + // Handle this task as a killed task + handleFailedTask(tid, TaskState.KILLED, + TaskKilled("Finish but did not commit due to another attempt succeeded")) + return + } + info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) if (speculationEnabled) { successfulTaskDurations.insert(info.duration) @@ -733,7 +757,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - killedByOtherAttempt(index) = true + killedByOtherAttempt += attemptInfo.taskId sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -754,6 +778,9 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // There may be multiple tasksets for this stage -- we let all of them know that the partition + // was completed. This may result in some of the tasksets getting completed. + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -764,6 +791,22 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } + private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = { + partitionToIndex.get(partitionId).foreach { index => + if (!successful(index)) { + if (speculationEnabled && !isZombie) { + successfulTaskDurations.insert(taskInfo.duration) + } + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + isZombie = true + } + maybeFinishTaskSet() + } + } + } + /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. 
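The new early-return in `handleSuccessfulTask` undoes driver-side result accounting when a late finisher loses to another attempt. That rollback is a small pattern of its own; here is a hedged sketch with a hypothetical `ResultSizeTracker` standing in for the real `totalResultSize`/`calculatedTasks` fields and `LongAccumulator` plumbing.

    object ResultSizeSketch {
      /** Simplified driver-side result-size budget (hypothetical, not Spark's TaskResultGetter path). */
      class ResultSizeTracker(maxResultSize: Long) {
        private var totalResultSize = 0L
        private var calculatedTasks = 0

        /** Account for a finished task's result; returns false once the budget is exceeded. */
        def canFetch(resultSize: Long): Boolean = {
          totalResultSize += resultSize
          calculatedTasks += 1
          maxResultSize == 0 || totalResultSize <= maxResultSize
        }

        /** Undo the accounting when the attempt turns out to be a losing duplicate. */
        def rollback(resultSize: Long): Unit = {
          totalResultSize -= resultSize
          calculatedTasks -= 1
        }

        def counted: Int = calculatedTasks
      }
    }

The diff then reclassifies the duplicate as a killed task rather than a success, which keeps the per-stage success counters consistent with the single attempt that actually committed.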
@@ -833,17 +876,27 @@ private[spark] class TaskSetManager( } ef.exception + case tk: TaskKilled => + // TaskKilled might have accumulator updates + accumUpdates = tk.accums + logWarning(failureReason) + None + case e: ExecutorLostFailure if !e.exitCausedByApp => logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None - case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others + case e: TaskFailedReason => // TaskResultLost and others logWarning(failureReason) None } + if (tasks(index).isBarrier) { + isZombie = true + } + sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (!isZombie && reason.countTowardsTaskFailures) { @@ -920,7 +973,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index) && !killedByOtherAttempt(index)) { + if (successful(index) && !killedByOtherAttempt.contains(tid)) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 @@ -952,8 +1005,8 @@ private[spark] class TaskSetManager( */ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a - // zombie. - if (isZombie || numTasks == 1) { + // zombie or is from a barrier stage. + if (isZombie || isBarrier || numTasks == 1) { return false } var foundTasks = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index 810b36cddf835..6ec74913e42f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -21,4 +21,10 @@ package org.apache.spark.scheduler * Represents free resources available on an executor. */ private[spark] -case class WorkerOffer(executorId: String, host: String, cores: Int) +case class WorkerOffer( + executorId: String, + host: String, + cores: Int, + // `address` is an optional hostPort string, it provide more useful information than `host` + // when multiple executors are launched on the same host. + address: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5627a557a12f3..747e8c7dc0fa5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -170,8 +170,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) - } else if (scheduler.nodeBlacklist != null && - scheduler.nodeBlacklist.contains(hostname)) { + } else if (scheduler.nodeBlacklist.contains(hostname)) { // If the cluster manager gives us an executor on a blacklisted node (because it // already started allocating those resources before we informed it of our blacklist, // or if it ignored our blacklist), then we reject that executor immediately. 
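The optional `address` added to `WorkerOffer` above, together with the per-backend `maxNumConcurrentTasks()` overrides that follow, gives the scheduler enough information to count distinct task slots even when several executors share one host. A hedged sketch of that slot arithmetic, using a hypothetical `ExecutorInfo` in place of the backend's `ExecutorData`:

    object SlotMathSketch {
      // Hypothetical, simplified view of one registered executor.
      final case class ExecutorInfo(executorId: String, host: String, hostPort: String, totalCores: Int)

      /** Cluster-wide ceiling on concurrently running tasks, as used by the barrier check. */
      def maxNumConcurrentTasks(executors: Iterable[ExecutorInfo], cpusPerTask: Int): Int =
        executors.map(_.totalCores / cpusPerTask).sum

      def main(args: Array[String]): Unit = {
        // Two 4-core executors on the same host; with 2 CPUs per task that is 4 slots,
        // and the distinct hostPort values keep the two executors distinguishable.
        val execs = Seq(
          ExecutorInfo("0", "host-a", "host-a:4040", 4),
          ExecutorInfo("1", "host-a", "host-a:4041", 4))
        println(maxNumConcurrentTasks(execs, cpusPerTask = 2)) // 4
      }
    }

This is why `WorkerOffer.address` is a hostPort string rather than another copy of `host`: the host alone cannot tell two co-located executors apart.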
@@ -243,7 +242,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + new WorkerOffer(id, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort)) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -268,7 +268,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) val workOffers = IndexedSeq( - new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort))) scheduler.resourceOffers(workOffers) } else { Seq.empty @@ -495,6 +496,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorDataMap.keySet.toSeq } + override def maxNumConcurrentTasks(): Int = { + executorDataMap.values.map { executor => + executor.totalCores / scheduler.CPUS_PER_TASK + }.sum + } + /** * Request an additional number of executors from the cluster manager. * @return whether the request is acknowledged. @@ -633,7 +640,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } doRequestTotalExecutors(requestedTotalExecutors) } else { - numPendingExecutors += knownExecutors.size + numPendingExecutors += executorsToKill.size Future.successful(true) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 4c614c5c0f602..0de57fbd5600c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -81,7 +81,8 @@ private[spark] class LocalEndpoint( } def reviveOffers() { - val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores, + Some(rpcEnv.address.hostPort))) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, task) @@ -155,6 +156,8 @@ private[spark] class LocalSchedulerBackend( override def applicationId(): String = appId + override def maxNumConcurrentTasks(): Int = totalCores / scheduler.CPUS_PER_TASK + private def stop(finalState: SparkAppHandle.State): Unit = { localEndpoint.ask(StopExecutor) try { diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala new file mode 100644 index 0000000000000..ea38ccb289c30 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.security + +import java.io.{DataInputStream, DataOutputStream, InputStream} +import java.net.Socket +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +/** + * A class that can be used to add a simple authentication protocol to socket-based communication. + * + * The protocol is simple: an auth secret is written to the socket, and the other side checks the + * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is + * not expected to be valid anymore. + * + * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. + */ +private[spark] class SocketAuthHelper(conf: SparkConf) { + + val secret = Utils.createSecret(conf) + + /** + * Read the auth secret from the socket and compare to the expected value. Write the reply back + * to the socket. + * + * If authentication fails or error is thrown, this method will close the socket. + * + * @param s The client socket. + * @throws IllegalArgumentException If authentication fails. + */ + def authClient(s: Socket): Unit = { + var shouldClose = true + try { + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + shouldClose = false + } else { + writeUtf8("err", s) + throw new IllegalArgumentException("Authentication failed.") + } + } finally { + s.setSoTimeout(currentTimeout) + } + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(s) + } + } + } + + /** + * Authenticate with a server by writing the auth secret and checking the server's reply. + * + * If authentication fails or error is thrown, this method will close the socket. + * + * @param s The socket connected to the server. + * @throws IllegalArgumentException If authentication fails. 
+ */ + def authToServer(s: Socket): Unit = { + var shouldClose = true + try { + writeUtf8(secret, s) + + val reply = readUtf8(s) + if (reply != "ok") { + throw new IllegalArgumentException("Authentication failed.") + } else { + shouldClose = false + } + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(s) + } + } + } + + protected def readUtf8(s: Socket): String = { + val din = new DataInputStream(s.getInputStream()) + val len = din.readInt() + val bytes = new Array[Byte](len) + din.readFully(bytes) + new String(bytes, UTF_8) + } + + protected def writeUtf8(str: String, s: Socket): Unit = { + val bytes = str.getBytes(UTF_8) + val dout = new DataOutputStream(s.getOutputStream()) + dout.writeInt(bytes.length) + dout.write(bytes, 0, bytes.length) + dout.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4103dfb10175e..74b0e0b3a741a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -104,7 +104,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) // Use completion callback to stop sorter if task was finished/cancelled. - context.addTaskCompletionListener(_ => { + context.addTaskCompletionListener[Unit](_ => { sorter.stop() }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d9fad64f34c7c..0caf84c6050a8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.shuffle._ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then * written to a single map output file. Reducers fetch contiguous regions of this file in order to * read their portion of the map output. In cases where the map output data is too large to fit in - * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * memory, sorted subsets of the output can be spilled to disk and those on-disk files are merged * to produce the final output file. 
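The `SocketAuthHelper` handshake added above is symmetric and very small: each side writes a length-prefixed UTF-8 string and reads one back, with "ok"/"err" as the server's verdict. A hedged usage sketch of the client side of that wire protocol, independent of the Spark class (the 10-second socket timeout and the secret comparison live on the server side and are omitted here; `secret` would come from the shared SparkConf):

    import java.io.{DataInputStream, DataOutputStream}
    import java.net.Socket
    import java.nio.charset.StandardCharsets.UTF_8

    object AuthClientSketch {
      private def write(s: Socket, msg: String): Unit = {
        val out = new DataOutputStream(s.getOutputStream)
        val bytes = msg.getBytes(UTF_8)
        out.writeInt(bytes.length) // 4-byte length prefix, then the UTF-8 payload
        out.write(bytes)
        out.flush()
      }

      private def read(s: Socket): String = {
        val in = new DataInputStream(s.getInputStream)
        val bytes = new Array[Byte](in.readInt())
        in.readFully(bytes)
        new String(bytes, UTF_8)
      }

      /** Sends the secret and expects the literal reply "ok"; anything else is a failure. */
      def authenticate(s: Socket, secret: String): Unit = {
        write(s, secret)
        if (read(s) != "ok") {
          s.close()
          throw new IllegalArgumentException("Authentication failed.")
        }
      }
    }

As the class comment says, the secret travels in the clear, so the scheme only makes sense over loopback or an already-encrypted channel.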
* * Sort-based shuffle has two different write paths for producing its map output files: diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 274399b9cc1f3..91fc26762e533 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,8 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, + writeMetrics.recordsWritten) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 688f25a9fdea1..e237281c552b1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -471,7 +471,7 @@ private[spark] class AppStatusStore( def operationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { val job = store.read(classOf[JobDataWrapper], jobId) - val stages = job.info.stageIds + val stages = job.info.stageIds.sorted stages.map { id => val g = store.read(classOf[RDDOperationGraphWrapper], id).toRDDOperationGraph() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 7127397f6205c..84c2ad48f1f27 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -28,7 +28,7 @@ import org.glassfish.jersey.server.ServerProperties import org.glassfish.jersey.servlet.ServletContainer import org.apache.spark.SecurityManager -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, UIUtils} /** * Main entry point for serving spark application metrics as json, using JAX-RS. 
@@ -49,6 +49,7 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}") def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] + @GET @Path("version") def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) @@ -147,38 +148,18 @@ private[v1] trait BaseAppResource extends ApiRequestContext { } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( - Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + UIUtils.buildErrorResponse(Response.Status.FORBIDDEN, msg)) private[v1] class NotFoundException(msg: String) extends WebApplicationException( - new NoSuchElementException(msg), - Response - .status(Response.Status.NOT_FOUND) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.NOT_FOUND, msg)) private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( - new ServiceUnavailableException(msg), - Response - .status(Response.Status.SERVICE_UNAVAILABLE) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.SERVICE_UNAVAILABLE, msg)) private[v1] class BadParameterException(msg: String) extends WebApplicationException( - new IllegalArgumentException(msg), - Response - .status(Response.Status.BAD_REQUEST) - .entity(ErrorWrapper(msg)) - .build() -) { + UIUtils.buildErrorResponse(Response.Status.BAD_REQUEST, msg)) { def this(param: String, exp: String, actual: String) = { this(raw"""Bad value for parameter "$param". Expected a $exp, got "$actual"""") } } -/** - * Signal to JacksonMessageWriter to not convert the message into json (which would result in an - * extra set of quotes). - */ -private[v1] case class ErrorWrapper(s: String) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index 76af33c1a18db..4560d300cb0c8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -68,10 +68,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ mediaType: MediaType, multivaluedMap: MultivaluedMap[String, AnyRef], outputStream: OutputStream): Unit = { - t match { - case ErrorWrapper(err) => outputStream.write(err.getBytes(StandardCharsets.UTF_8)) - case _ => mapper.writeValue(outputStream, t) - } + mapper.writeValue(outputStream, t) } override def getSize( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 974697890dd03..32100c5704538 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -140,11 +140,8 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) .build() } catch { - case NonFatal(e) => - Response.serverError() - .entity(s"Event logs are not available for app: $appId.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() + case NonFatal(_) => + throw new ServiceUnavailable(s"Event logs are not available for app: $appId.") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 
e0276a4dc4224..e7cdfab99b34d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -41,10 +41,12 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -129,6 +131,10 @@ private[spark] class BlockManager( private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + private val chunkSize = + conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString).toInt + private val remoteReadNioBufferConversion = + conf.getBoolean("spark.network.remoteReadNioBufferConversion", false) val diskBlockManager = { // Only perform cleanup if an external service is not serving our shuffle files. @@ -291,7 +297,7 @@ private[spark] class BlockManager( case e: Exception if i < MAX_ATTEMPTS => logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) - Thread.sleep(SLEEP_TIME_SECS * 1000) + Thread.sleep(SLEEP_TIME_SECS * 1000L) case NonFatal(e) => throw new SparkException("Unable to register with external shuffle server due to : " + e.getMessage, e) @@ -401,6 +407,63 @@ private[spark] class BlockManager( putBytes(blockId, new ChunkedByteBuffer(data.nioByteBuffer()), level)(classTag) } + override def putBlockDataAsStream( + blockId: BlockId, + level: StorageLevel, + classTag: ClassTag[_]): StreamCallbackWithID = { + // TODO if we're going to only put the data in the disk store, we should just write it directly + // to the final location, but that would require a deeper refactor of this code. So instead + // we just write to a temp file, and call putBytes on the data in that file. + val tmpFile = diskBlockManager.createTempLocalBlock()._2 + val channel = new CountingWritableChannel( + Channels.newChannel(serializerManager.wrapForEncryption(new FileOutputStream(tmpFile)))) + logTrace(s"Streaming block $blockId to tmp file $tmpFile") + new StreamCallbackWithID { + + override def getID: String = blockId.name + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.hasRemaining) { + channel.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + logTrace(s"Done receiving block $blockId, now putting into local blockManager") + // Read the contents of the downloaded file as a buffer to put into the blockManager. + // Note this is all happening inside the netty thread as soon as it reads the end of the + // stream. + channel.close() + // TODO SPARK-25035 Even if we're only going to write the data to disk after this, we end up + // using a lot of memory here. With encryption, we'll read the whole file into a regular + // byte buffer and OOM. Without encryption, we'll memory map the file and won't get a jvm + // OOM, but might get killed by the OS / cluster manager. 
We could at least read the tmp + // file as a stream in both cases. + val buffer = securityManager.getIOEncryptionKey() match { + case Some(key) => + // we need to pass in the size of the unencrypted block + val blockSize = channel.getCount + val allocator = level.memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + new EncryptedBlockData(tmpFile, blockSize, conf, key).toChunkedByteBuffer(allocator) + + case None => + ChunkedByteBuffer.map(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt) + } + putBytes(blockId, buffer, level)(classTag) + tmpFile.delete() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + // the framework handles the connection itself, we just need to do local cleanup + channel.close() + tmpFile.delete() + } + } + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing. @@ -659,6 +722,11 @@ private[spark] class BlockManager( * Get block from remote block managers as serialized bytes. */ def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + // TODO if we change this method to return the ManagedBuffer, then getRemoteValues + // could just use the inputStream on the temp file, rather than memory-mapping the file. + // Until then, replication can cause the process to use too much memory and get killed + // by the OS / cluster manager (not a java OOM, since it's a memory-mapped file) even though + // we've read the data to disk. logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") var runningFailureCount = 0 @@ -689,7 +757,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager) } catch { case NonFatal(e) => runningFailureCount += 1 @@ -723,7 +791,14 @@ private[spark] class BlockManager( } if (data != null) { - return Some(new ChunkedByteBuffer(data)) + // SPARK-24307 undocumented "escape-hatch" in case there are any issues in converting to + // ChunkedByteBuffer, to go back to old code-path. Can be removed post Spark 2.4 if + // new path is stable. + if (remoteReadNioBufferConversion) { + return Some(new ChunkedByteBuffer(data.nioByteBuffer())) + } else { + return Some(ChunkedByteBuffer.fromManagedBuffer(data, chunkSize)) + } } logDebug(s"The value of block $blockId is null") } @@ -1341,12 +1416,16 @@ private[spark] class BlockManager( try { val onePeerStartTime = System.nanoTime logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer") + // This thread keeps a lock on the block, so we do not want the netty thread to unlock + // block when it finishes sending the message. 
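`putBlockDataAsStream` above follows a write-to-temp-file-then-put shape: chunks arrive on the netty thread, are appended to a temp file through a counting channel, and only `onComplete` moves the block into the store. A hedged sketch of that callback shape using plain `java.nio` channels and a hypothetical `putBytes` function (no encryption, memory-mode, or ChunkedByteBuffer handling):

    import java.io.{File, FileOutputStream}
    import java.nio.ByteBuffer
    import java.nio.channels.Channels
    import java.nio.file.Files

    object StreamToTempFileSketch {
      /** Minimal stand-in for StreamCallbackWithID: buffer chunks in a temp file, then hand off. */
      class TempFileStream(blockId: String, putBytes: (String, Array[Byte]) => Unit) {
        private val tmpFile: File = File.createTempFile("block-" + blockId, ".tmp")
        private val channel = Channels.newChannel(new FileOutputStream(tmpFile))

        def onData(buf: ByteBuffer): Unit = {
          while (buf.hasRemaining) channel.write(buf) // a chunk may need several writes
        }

        def onComplete(): Unit = {
          channel.close()
          // The real code avoids reading the whole file into memory (see the SPARK-25035 TODO).
          putBytes(blockId, Files.readAllBytes(tmpFile.toPath))
          tmpFile.delete()
        }

        def onFailure(cause: Throwable): Unit = {
          channel.close()
          tmpFile.delete()
        }
      }
    }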
+ val buffer = new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false, + unlockOnDeallocate = false) blockTransferService.uploadBlockSync( peer.host, peer.port, peer.executorId, blockId, - new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false), + buffer, tLevel, classTag) logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + @@ -1554,7 +1633,7 @@ private[spark] class BlockManager( private[spark] object BlockManager { private val ID_GENERATOR = new IdGenerator - def blockIdsToHosts( + def blockIdsToLocations( blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { @@ -1569,7 +1648,9 @@ private[spark] object BlockManager { val blockManagers = new HashMap[BlockId, Seq[String]] for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i).map(_.host) + blockManagers(blockIds(i)) = blockLocations(i).map { loc => + ExecutorCacheTaskLocation(loc.host, loc.executorId).toString + } } blockManagers.toMap } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index 3d3806126676c..5c12b5cee4d2f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -38,7 +38,8 @@ private[storage] class BlockManagerManagedBuffer( blockInfoManager: BlockInfoManager, blockId: BlockId, data: BlockData, - dispose: Boolean) extends ManagedBuffer { + dispose: Boolean, + unlockOnDeallocate: Boolean = true) extends ManagedBuffer { private val refCount = new AtomicInteger(1) @@ -58,7 +59,9 @@ private[storage] class BlockManagerManagedBuffer( } override def release(): ManagedBuffer = { - blockInfoManager.unlock(blockId) + if (unlockOnDeallocate) { + blockInfoManager.unlock(blockId) + } if (refCount.decrementAndGet() == 0 && dispose) { data.dispose() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 8e8f7d197c9ef..f984cf76e3463 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -54,7 +54,8 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. 
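Renaming `blockIdsToHosts` to `blockIdsToLocations` changes what callers receive: executor-scoped location strings instead of bare hostnames, so the scheduler can prefer the exact executor that caches a block. A hedged sketch of what such a string carries; the `executor_<host>_<id>` form matches how `ExecutorCacheTaskLocation` renders itself, but treat the exact format as an implementation detail.

    object LocationSketch {
      // Simplified stand-in for ExecutorCacheTaskLocation: a host plus the caching executor's id.
      final case class CachedOnExecutor(host: String, executorId: String) {
        override def toString: String = s"executor_${host}_$executorId"
      }

      def main(args: Array[String]): Unit = {
        val loc = CachedOnExecutor("host-a", "3")
        // Prints "executor_host-a_3": enough information for executor-level locality,
        // whereas the old API only exposed "host-a".
        println(loc.toString)
      }
    }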
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private val askThreadPool = + ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100) private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) private val topologyMapper = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 742cf4fe393f9..67544b20408a6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -37,7 +37,7 @@ class BlockManagerSlaveEndpoint( extends ThreadSafeRpcEndpoint with Logging { private val asyncThreadPool = - ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100) private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 39249d411b582..a820bc70b33b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -29,7 +29,7 @@ import com.google.common.io.Closeables import io.netty.channel.DefaultFileRegion import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils @@ -44,8 +44,7 @@ private[spark] class DiskStore( securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") - private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", - Int.MaxValue.toString) + private val maxMemoryMapBytes = conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS) private val blockSizes = new ConcurrentHashMap[BlockId, Long]() def getSize(blockId: BlockId): Long = blockSizes.get(blockId) @@ -279,7 +278,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: override def transferred(): Long = _transferred override def transferTo(target: WritableByteChannel, pos: Long): Long = { - assert(pos == transfered(), "Invalid position.") + assert(pos == transferred(), "Invalid position.") var written = 0L var lastWrite = -1L diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index e5abbf745cc41..9ccc8f9cc585b 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -17,7 +17,9 @@ package org.apache.spark.storage +import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @@ -53,10 +55,16 @@ class RDDInfo( } private[spark] object RDDInfo { + private val callsiteForm = 
SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_FORM) + def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) + val callSite = callsiteForm match { + case "short" => rdd.creationSite.shortForm + case "long" => rdd.creationSite.longForm + } new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) + rdd.getStorageLevel, parentIds, callSite, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index dd9df74689a13..00d01dd28afb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. + * order to throttle the memory usage. Note that zero-sized blocks are + * already excluded, which happened in + * [[MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -62,7 +64,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -74,8 +76,8 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ /** - * Total number of blocks to fetch. This can be smaller than the total number of blocks - * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * Total number of blocks to fetch. This should be equal to the total number of blocks + * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. * * This should equal localBlocks.size + remoteBlocks.size. */ @@ -267,13 +269,16 @@ final class ShuffleBlockFetcherIterator( // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] - // Tracks total number of blocks (including zero sized blocks) - var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size if (address.executorId == blockManager.blockManagerId.executorId) { - // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + blockInfos.find(_._2 <= 0) match { + case Some((blockId, size)) if size < 0 => + throw new BlockException(blockId, "Negative block size " + size) + case Some((blockId, size)) if size == 0 => + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + case None => // do nothing. 
+ } + localBlocks ++= blockInfos.map(_._1) numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator @@ -281,14 +286,15 @@ final class ShuffleBlockFetcherIterator( var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { + if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } else if (size == 0) { + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + } else { curBlocks += ((blockId, size)) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { @@ -306,7 +312,8 @@ final class ShuffleBlockFetcherIterator( } } } - logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" + + s" local blocks and ${remoteBlocks.size} remote blocks") remoteRequests } @@ -339,7 +346,7 @@ final class ShuffleBlockFetcherIterator( private[this] def initialize(): Unit = { // Add a task completion callback (called in both success case and failure case) to cleanup. - context.addTaskCompletionListener(_ => cleanup()) + context.addTaskCompletionListener[Unit](_ => cleanup()) // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() @@ -407,6 +414,25 @@ final class ShuffleBlockFetcherIterator( logDebug("Number of requests in flight " + reqsInFlight) } + if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + throwFetchFailedException(blockId, address, new IOException(msg)) + } + val in = try { buf.createInputStream() } catch { diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 4cc5bcb7f9baf..06fd56e54d9c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -827,7 +827,7 @@ private[storage] class PartiallySerializedBlock[T]( // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. // The dispose() method is idempotent, so it's safe to call it unconditionally. 
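The stricter size handling in `ShuffleBlockFetcherIterator` turns what used to be silent filtering into hard failures, on the grounds that zero-sized map outputs are already excluded in `MapOutputTracker.convertMapStatuses`. A hedged sketch of that validation as a standalone function, with a hypothetical `BlockException` in place of Spark's:

    object BlockSizeChecksSketch {
      final class BlockException(blockId: String, msg: String) extends Exception(s"$blockId: $msg")

      /** Reject impossible sizes; zero-sized blocks should have been dropped upstream. */
      def validate(blocks: Seq[(String, Long)]): Unit = blocks.foreach { case (blockId, size) =>
        if (size < 0) throw new BlockException(blockId, "Negative block size " + size)
        if (size == 0) throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
      }
    }

The later `buf.size == 0` check applies the same reasoning to the fetched buffer itself: receiving an empty buffer for a block that was advertised as non-empty indicates corruption, so the fetch fails fast instead of producing a silently truncated stream.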
Option(TaskContext.get()).foreach { taskContext => - taskContext.addTaskCompletionListener { _ => + taskContext.addTaskCompletionListener[Unit] { _ => // When a task completes, its unroll memory will automatically be freed. Thus we do not call // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. unrolledBuffer.dispose() diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0e8a6307de6a8..52a955111231a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -263,7 +263,7 @@ private[spark] object JettyUtils extends Logging { filters.foreach { case filter : String => if (!filter.isEmpty) { - logInfo("Adding filter: " + filter) + logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.") val holder : FilterHolder = new FilterHolder() holder.setClassName(filter) // Get any parameters for each filter @@ -344,6 +344,7 @@ private[spark] object JettyUtils extends Logging { connectionFactories: _*) connector.setPort(port) connector.setHost(hostName) + connector.setReuseAddress(!Utils.isWindows) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads @@ -406,7 +407,7 @@ private[spark] object JettyUtils extends Logging { } pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - ServerInfo(server, httpPort, securePort, collection) + ServerInfo(server, httpPort, securePort, conf, collection) } catch { case e: Exception => server.stop() @@ -506,10 +507,12 @@ private[spark] case class ServerInfo( server: Server, boundPort: Int, securePort: Option[Int], + conf: SparkConf, private val rootHandler: ContextHandlerCollection) { - def addHandler(handler: ContextHandler): Unit = { + def addHandler(handler: ServletContextHandler): Unit = { handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) + JettyUtils.addFilters(Seq(handler), conf) rootHandler.addHandler(handler) if (!handler.isStarted()) { handler.start() diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b44ac0ea1febc..d315ef66e0dc0 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -65,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 02cf19e00ecde..732b7528f499e 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,6 +20,8 @@ package org.apache.spark.ui import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} +import javax.servlet.http.HttpServletRequest +import javax.ws.rs.core.{MediaType, Response} import scala.util.control.NonFatal import scala.xml._ @@ -148,60 +150,71 @@ private[spark] object UIUtils extends Logging { } // Yarn has to go through a proxy so the base 
uri is provided and has to be on all links - def uiRoot: String = { + def uiRoot(request: HttpServletRequest): String = { + // Knox uses X-Forwarded-Context to notify the application the base path + val knoxBasePath = Option(request.getHeader("X-Forwarded-Context")) // SPARK-11484 - Use the proxyBase set by the AM, if not found then use env. sys.props.get("spark.ui.proxyBase") .orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE")) + .orElse(knoxBasePath) .getOrElse("") } - def prependBaseUri(basePath: String = "", resource: String = ""): String = { - uiRoot + basePath + resource + def prependBaseUri( + request: HttpServletRequest, + basePath: String = "", + resource: String = ""): String = { + uiRoot(request) + basePath + resource } - def commonHeaderNodes: Seq[Node] = { + def commonHeaderNodes(request: HttpServletRequest): Seq[Node] = { - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + } - def vizHeaderNodes: Seq[Node] = { - - - - - + def vizHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + + + + } - def dataTablesHeaderNodes: Seq[Node] = { + def dataTablesHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> - - - - - - - + href={prependBaseUri(request, "/static/jsonFormatter.min.css")} type="text/css"/> + + + + + + } /** Returns a spark page with correctly formatted headers */ def headerSparkPage( + request: HttpServletRequest, title: String, content: => Seq[Node], activeTab: SparkUITab, @@ -214,25 +227,26 @@ private[spark] object UIUtils extends Logging { val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } val helpButton: Seq[Node] = helpText.map(tooltip(_, "bottom")).getOrElse(Seq.empty) - {commonHeaderNodes} - {if (showVisualization) vizHeaderNodes else Seq.empty} - {if (useDataTables) dataTablesHeaderNodes else Seq.empty} - + {commonHeaderNodes(request)} + {if (showVisualization) vizHeaderNodes(request) else Seq.empty} + {if (useDataTables) dataTablesHeaderNodes(request) else Seq.empty} + {appName} - {title} } - UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) + UIUtils.headerSparkPage( + request, s"Details for Job $jobId", content, parent, showVisualization = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index a3e1f13782e30..22a40101e33df 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -49,7 +49,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { "stages/pool", parent.isFairScheduler, parent.killEnabled, false) val poolTable = new PoolTable(Map(pool -> uiPool), parent) - var content =
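The `uiRoot` change above adds a third source for the UI's base path: Knox's `X-Forwarded-Context` header, consulted only when neither the proxy-base system property nor the environment variable is set. A hedged sketch of that precedence order, with a plain header-lookup parameter instead of `HttpServletRequest`:

    object ProxyBaseSketch {
      /** spark.ui.proxyBase wins, then APPLICATION_WEB_PROXY_BASE, then the Knox header, else "". */
      def uiRoot(
          sysProp: Option[String],
          envVar: Option[String],
          forwardedContext: Option[String]): String =
        sysProp.orElse(envVar).orElse(forwardedContext).getOrElse("")

      def main(args: Array[String]): Unit = {
        // Only the Knox header is present, so it supplies the base path.
        println(uiRoot(None, None, Some("/gateway/sandbox"))) // /gateway/sandbox
      }
    }

Threading the `request` through `prependBaseUri` and the header-node builders is what lets this per-request value reach every generated link.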

    Summary

    ++ poolTable.toNodeSeq + var content =

    Summary

    ++ poolTable.toNodeSeq(request) if (activeStages.nonEmpty) { content ++= } - UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) + UIUtils.headerSparkPage(request, "Fair Scheduler Pool: " + poolName, content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 5dfce858dec07..96b5f72393070 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder +import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -28,7 +29,7 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab) { - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = { @@ -39,15 +40,15 @@ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab - {pools.map { case (s, p) => poolRow(s, p) }} + {pools.map { case (s, p) => poolRow(request, s, p) }}
    Pool NameSchedulingMode
    } - private def poolRow(s: Schedulable, p: PoolData): Seq[Node] = { + private def poolRow(request: HttpServletRequest, s: Schedulable, p: PoolData): Seq[Node] = { val activeStages = p.stageIds.size val href = "%s/stages/pool?poolname=%s" - .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) + .format(UIUtils.prependBaseUri(request, parent.basePath), URLEncoder.encode(p.name, "UTF-8")) {p.name} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ac83de10f9237..55eb989962668 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,20 +112,19 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    No information to display for Stage {stageId} (Attempt {stageAttemptId})

    - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks + val totalTasks = taskCount(stageData) if (totalTasks == 0) { val content =

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet
    - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) @@ -133,7 +132,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${storedTasks}" + s"$storedTasks, showing ${totalTasks}" } val summary = @@ -282,8 +281,8 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( stageData, - UIUtils.prependBaseUri(parent.basePath) + - s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + UIUtils.prependBaseUri(request, parent.basePath) + + s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, @@ -498,7 +497,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    {taskTableHTML ++ jsForScrollingDownToTaskTable}
    - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) } def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { @@ -686,7 +685,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numTasks + override def dataSize: Int = taskCount(stage) override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { @@ -1052,4 +1051,9 @@ private[ui] object ApiHelper { (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } + def taskCount(stageData: StageData): Int = { + stageData.numActiveTasks + stageData.numCompleteTasks + stageData.numFailedTasks + + stageData.numKilledTasks + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 18a4926f2f6c0..d01acdae59c9f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -43,7 +43,9 @@ private[ui] class StageTableBase( killEnabled: Boolean, isFailedStage: Boolean) { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) @@ -92,7 +94,8 @@ private[ui] class StageTableBase( stageSortColumn, stageSortDesc, isFailedStage, - parameterOtherTable + parameterOtherTable, + request ).table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -147,7 +150,8 @@ private[ui] class StagePagedTable( sortColumn: String, desc: Boolean, isFailedStage: Boolean, - parameterOtherTable: Iterable[String]) extends PagedTable[StageTableRowData] { + parameterOtherTable: Iterable[String], + request: HttpServletRequest) extends PagedTable[StageTableRowData] { override def tableId: String = stageTag + "-table" @@ -161,7 +165,7 @@ private[ui] class StagePagedTable( override def pageNumberFormField: String = stageTag + ".page" - val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" + + val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" 
+ parameterOtherTable.mkString("&") override val dataSource = new StageDataSource( @@ -288,7 +292,7 @@ private[ui] class StagePagedTable( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(request, basePath), data.schedulingPool)}> {data.schedulingPool} @@ -346,7 +350,7 @@ private[ui] class StagePagedTable( } private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { - val basePathUri = UIUtils.prependBaseUri(basePath) + val basePathUri = UIUtils.prependBaseUri(request, basePath) val killLink = if (killEnabled) { val confirm = @@ -366,7 +370,7 @@ private[ui] class StagePagedTable( Seq.empty } - val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" + val nameLinkUri = s"$basePathUri/stages/stage/?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} val cachedRddInfos = store.rddList().filter { rdd => s.rddIds.contains(rdd.id) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 2674b9291203a..238cd31433660 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -53,7 +53,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } catch { case _: NoSuchElementException => // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) + return UIUtils.headerSparkPage(request, "RDD Not Found", Seq.empty[Node], parent) } // Worker table @@ -72,7 +72,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } val blockTableHTML = try { val _blockTable = new BlockPagedTable( - UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", + UIUtils.prependBaseUri(request, parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, @@ -145,7 +145,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web {blockTableHTML ++ jsForScrollingDownToBlockTable} ; - UIUtils.headerSparkPage("RDD Storage Info for " + rddStorageInfo.name, content, parent) + UIUtils.headerSparkPage( + request, "RDD Storage Info for " + rddStorageInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 68d946574a37b..3eb546e336e99 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -31,11 +31,14 @@ import org.apache.spark.util.Utils private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - val content = rddTable(store.rddList()) ++ receiverBlockTables(store.streamBlocksList()) - UIUtils.headerSparkPage("Storage", content, parent) + val content = rddTable(request, store.rddList()) ++ + receiverBlockTables(store.streamBlocksList()) + UIUtils.headerSparkPage(request, "Storage", content, parent) } - private[storage] def rddTable(rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { + private[storage] def rddTable( + request: HttpServletRequest, + rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { if (rdds.isEmpty) { // Don't show the rdd table if there is no RDD persisted. 
Nil @@ -49,7 +52,11 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends
    - {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + {UIUtils.listingTable( + rddHeader, + rddRow(request, _: v1.RDDStorageInfo), + rdds, + id = Some("storage-by-rdd-table"))}
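Passing rddRow(request, _: v1.RDDStorageInfo) turns the two-argument row renderer into the one-argument function a generic listing helper expects. A stand-alone sketch of the same partial-application trick; the renderTable and Row names are made up for illustration:

object PartialApplicationDemo {
  final case class Row(id: Int, name: String)

  // A generic helper that only knows how to apply a one-argument row renderer.
  def renderTable[T](rows: Seq[T], renderRow: T => String): String =
    rows.map(renderRow).mkString("\n")

  // A renderer that additionally needs per-request context (here just a string prefix).
  def rowWithContext(requestBase: String, row: Row): String =
    s"""<tr><td><a href="$requestBase/rdd/?id=${row.id}">${row.name}</a></td></tr>"""

  def main(args: Array[String]): Unit = {
    val rows = Seq(Row(1, "cachedRdd"), Row(2, "broadcastRdd"))
    // Partially apply the request-dependent argument; the result is a Row => String.
    println(renderTable(rows, rowWithContext("/proxy/app-1234", _: Row)))
  }
}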
    } @@ -66,12 +73,13 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends "Size on Disk") /** Render an HTML row representing an RDD */ - private def rddRow(rdd: v1.RDDStorageInfo): Seq[Node] = { + private def rddRow(request: HttpServletRequest, rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off {rdd.id} - + {rdd.name} diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 0f84ea9752cf5..bf618b4afbce0 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo private[spark] case class AccumulatorMetadata( @@ -199,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } override def toString: String = { + // getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead if (metadata == null) { - "Un-registered Accumulator: " + getClass.getSimpleName + "Un-registered Accumulator: " + Utils.getSimpleName(getClass) } else { - getClass.getSimpleName + s"(id: $id, name: $name, value: $value)" + Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)" } } } @@ -211,7 +214,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { /** * An internal class used to track accumulators by Spark itself. */ -private[spark] object AccumulatorContext { +private[spark] object AccumulatorContext extends Logging { /** * This global map holds the original accumulator objects that are created on the driver. @@ -258,13 +261,16 @@ private[spark] object AccumulatorContext { * Returns the [[AccumulatorV2]] registered with the given ID, if any. */ def get(id: Long): Option[AccumulatorV2[_, _]] = { - Option(originals.get(id)).map { ref => - // Since we are storing weak references, we must check whether the underlying data is valid. + val ref = originals.get(id) + if (ref eq null) { + None + } else { + // Since we are storing weak references, warn when the underlying data is not valid. 
val acc = ref.get if (acc eq null) { - throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") + logWarning(s"Attempted to access garbage collected accumulator $id") } - acc + Option(acc) } } @@ -486,7 +492,9 @@ class LegacyAccumulatorWrapper[R, T]( param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero: Boolean = _value == param.zero(initialValue) + @transient private lazy val _zero = param.zero(initialValue) + + override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) @@ -495,7 +503,7 @@ class LegacyAccumulatorWrapper[R, T]( } override def reset(): Unit = { - _value = param.zero(initialValue) + _value = _zero } override def add(v: T): Unit = _value = param.addAccumulator(_value, v) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ad0c0639521f6..b6c300c4778b1 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -18,12 +18,13 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.lang.invoke.SerializedLambda import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging @@ -33,6 +34,8 @@ import org.apache.spark.internal.Logging */ private[spark] object ClosureCleaner extends Logging { + private val isScala2_11 = scala.util.Properties.versionString.contains("2.11") + // Get an ASM class reader for a given class from the JAR that loaded it private[util] def getClassReader(cls: Class[_]): ClassReader = { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. @@ -159,6 +162,42 @@ private[spark] object ClosureCleaner extends Logging { clean(closure, checkSerializable, cleanTransitively, Map.empty) } + /** + * Try to get a serialized Lambda from the closure. + * + * @param closure the closure to check. + */ + private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = { + if (isScala2_11) { + return None + } + val isClosureCandidate = + closure.getClass.isSynthetic && + closure + .getClass + .getInterfaces.exists(_.getName.equals("scala.Serializable")) + + if (isClosureCandidate) { + try { + Option(inspect(closure)) + } catch { + case e: Exception => + // no need to check if debug is enabled here the Spark + // logging api covers this. + logDebug("Closure is not a serialized lambda.", e) + None + } + } else { + None + } + } + + private def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[java.lang.invoke.SerializedLambda] + } + /** * Helper method to clean the given closure in place. 
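The AccumulatorContext.get change above swaps an IllegalStateException for a logged warning when the weakly referenced accumulator has already been collected. The same pattern in isolation, with a println standing in for logWarning:

import java.lang.ref.WeakReference
import java.util.concurrent.ConcurrentHashMap

object WeakRegistry {
  private val entries = new ConcurrentHashMap[Long, WeakReference[AnyRef]]()

  def register(id: Long, value: AnyRef): Unit =
    entries.put(id, new WeakReference[AnyRef](value))

  def get(id: Long): Option[AnyRef] = {
    val ref = entries.get(id)
    if (ref eq null) {
      None                                   // never registered (or already removed)
    } else {
      val value = ref.get
      if (value eq null) {
        // The referent was garbage collected: warn instead of throwing, and return None.
        println(s"Attempted to access garbage collected entry $id")
      }
      Option(value)
    }
  }
}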
* @@ -206,7 +245,12 @@ private[spark] object ClosureCleaner extends Logging { cleanTransitively: Boolean, accessedFields: Map[Class[_], Set[String]]): Unit = { - if (!isClosure(func.getClass)) { + // most likely to be the case with 2.12, 2.13 + // so we check first + // non LMF-closures should be less frequent from now on + val lambdaFunc = getSerializedLambda(func) + + if (!isClosure(func.getClass) && lambdaFunc.isEmpty) { logDebug(s"Expected a closure; got ${func.getClass.getName}") return } @@ -218,118 +262,132 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") - - // A list of classes that represents closures enclosed in the given one - val innerClasses = getInnerClosureClasses(func) - - // A list of enclosing objects and their respective classes, from innermost to outermost - // An outer object at a given index is of type outer class at the same index - val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) - - // For logging purposes only - val declaredFields = func.getClass.getDeclaredFields - val declaredMethods = func.getClass.getDeclaredMethods - - if (log.isDebugEnabled) { - logDebug(" + declared fields: " + declaredFields.size) - declaredFields.foreach { f => logDebug(" " + f) } - logDebug(" + declared methods: " + declaredMethods.size) - declaredMethods.foreach { m => logDebug(" " + m) } - logDebug(" + inner classes: " + innerClasses.size) - innerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer classes: " + outerClasses.size) - outerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer objects: " + outerObjects.size) - outerObjects.foreach { o => logDebug(" " + o) } - } + if (lambdaFunc.isEmpty) { + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") + + // A list of classes that represents closures enclosed in the given one + val innerClasses = getInnerClosureClasses(func) + + // A list of enclosing objects and their respective classes, from innermost to outermost + // An outer object at a given index is of type outer class at the same index + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) + + // For logging purposes only + val declaredFields = func.getClass.getDeclaredFields + val declaredMethods = func.getClass.getDeclaredMethods + + if (log.isDebugEnabled) { + logDebug(s" + declared fields: ${declaredFields.size}") + declaredFields.foreach { f => logDebug(s" $f") } + logDebug(s" + declared methods: ${declaredMethods.size}") + declaredMethods.foreach { m => logDebug(s" $m") } + logDebug(s" + inner classes: ${innerClasses.size}") + innerClasses.foreach { c => logDebug(s" ${c.getName}") } + logDebug(s" + outer classes: ${outerClasses.size}" ) + outerClasses.foreach { c => logDebug(s" ${c.getName}") } + logDebug(s" + outer objects: ${outerObjects.size}") + outerObjects.foreach { o => logDebug(s" $o") } + } - // Fail fast if we detect return statements in closures - getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) - - // If accessed fields is not populated yet, we assume that - // the closure we are trying to clean is the starting one - if (accessedFields.isEmpty) { - logDebug(s" + populating accessed fields because this is the starting closure") - // Initialize accessed fields with the outer classes first - // This step is needed to associate the fields to the correct classes later - initAccessedFields(accessedFields, outerClasses) - - // Populate accessed fields by visiting all 
fields and methods accessed by this and - // all of its inner closures. If transitive cleaning is enabled, this may recursively - // visits methods that belong to other classes in search of transitively referenced fields. - for (cls <- func.getClass :: innerClasses) { - getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + // Fail fast if we detect return statements in closures + getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) + + // If accessed fields is not populated yet, we assume that + // the closure we are trying to clean is the starting one + if (accessedFields.isEmpty) { + logDebug(" + populating accessed fields because this is the starting closure") + // Initialize accessed fields with the outer classes first + // This step is needed to associate the fields to the correct classes later + initAccessedFields(accessedFields, outerClasses) + + // Populate accessed fields by visiting all fields and methods accessed by this and + // all of its inner closures. If transitive cleaning is enabled, this may recursively + // visits methods that belong to other classes in search of transitively referenced fields. + for (cls <- func.getClass :: innerClasses) { + getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + } } - } - logDebug(s" + fields accessed by starting closure: " + accessedFields.size) - accessedFields.foreach { f => logDebug(" " + f) } - - // List of outer (class, object) pairs, ordered from outermost to innermost - // Note that all outer objects but the outermost one (first one in this list) must be closures - var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var parent: AnyRef = null - if (outerPairs.size > 0) { - val (outermostClass, outermostObject) = outerPairs.head - if (isClosure(outermostClass)) { - logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") - } else if (outermostClass.getName.startsWith("$line")) { - // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it - // as it may carray a lot of unnecessary information, e.g. hadoop conf, spark conf, etc. - logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + logDebug(s" + fields accessed by starting closure: " + accessedFields.size) + accessedFields.foreach { f => logDebug(" " + f) } + + // List of outer (class, object) pairs, ordered from outermost to innermost + // Note that all outer objects but the outermost one (first one in this list) must be closures + var outerPairs: List[(Class[_], AnyRef)] = outerClasses.zip(outerObjects).reverse + var parent: AnyRef = null + if (outerPairs.nonEmpty) { + val (outermostClass, outermostObject) = outerPairs.head + if (isClosure(outermostClass)) { + logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone + // and clean it as it may carray a lot of unnecessary information, + // e.g. hadoop conf, spark conf, etc. + logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. 
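To make the comment above concrete: once a closure reads an instance member, the compiled function keeps a reference to the enclosing user object, and the cleaner deliberately does not clone or strip that object, so everything it holds travels with the closure. A small, self-contained illustration, with plain Java serialization standing in for Spark's closure serializer:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

class Driver extends Serializable {
  val hugeState: Array[Double] = Array.fill(1 << 20)(0.0) // roughly 8 MB dragged along
  val multiplier: Int = 3
  // The closure reads `multiplier`, so it holds on to the whole Driver instance.
  def makeMapper: Int => Int = x => x * multiplier
}

object CaptureCost {
  private def serializedSize(obj: AnyRef): Int = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(obj)
    out.close()
    buffer.size()
  }

  def main(args: Array[String]): Unit = {
    // Megabytes rather than a handful of bytes, because hugeState rides along.
    println(serializedSize(new Driver().makeMapper))
  }
}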
+ // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(" + outermost object is not a closure or REPL line object," + + "so do not clone it: " + outerPairs.head) + parent = outermostObject // e.g. SparkContext + outerPairs = outerPairs.tail + } } else { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - // Note that we still need to keep around the outermost object itself because - // we need it to clone its child closure later (see below). - logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " + - outerPairs.head) - parent = outermostObject // e.g. SparkContext - outerPairs = outerPairs.tail + logDebug(" + there are no enclosing objects!") } - } else { - logDebug(" + there are no enclosing objects!") - } - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. - for ((cls, obj) <- outerPairs) { - logDebug(s" + cloning the object $obj of class ${cls.getName}") - // We null out these unused references by cloning each object and then filling in all - // required fields from the original object. We need the parent here because the Java - // language specification requires the first constructor parameter of any closure to be - // its enclosing object. - val clone = cloneAndSetFields(parent, obj, cls, accessedFields) - - // If transitive cleaning is enabled, we recursively clean any enclosing closure using - // the already populated accessed fields map of the starting closure - if (cleanTransitively && isClosure(clone.getClass)) { - logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") - // No need to check serializable here for the outer closures because we're - // only interested in the serializability of the starting closure - clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + // Clone the closure objects themselves, nulling out any fields that are not + // used in the closure we're working on or any of its inner closures. + for ((cls, obj) <- outerPairs) { + logDebug(s" + cloning the object $obj of class ${cls.getName}") + // We null out these unused references by cloning each object and then filling in all + // required fields from the original object. We need the parent here because the Java + // language specification requires the first constructor parameter of any closure to be + // its enclosing object. 
+ val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + + // If transitive cleaning is enabled, we recursively clean any enclosing closure using + // the already populated accessed fields map of the starting closure + if (cleanTransitively && isClosure(clone.getClass)) { + logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") + // No need to check serializable here for the outer closures because we're + // only interested in the serializability of the starting closure + clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + } + parent = clone } - parent = clone - } - // Update the parent pointer ($outer) of this closure - if (parent != null) { - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - // If the starting closure doesn't actually need our enclosing object, then just null it out - if (accessedFields.contains(func.getClass) && - !accessedFields(func.getClass).contains("$outer")) { - logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") - field.set(func, null) - } else { - // Update this closure's parent pointer to point to our enclosing object, - // which could either be a cloned closure or the original user object - field.set(func, parent) + // Update the parent pointer ($outer) of this closure + if (parent != null) { + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + // If the starting closure doesn't actually need our enclosing object, then just null it out + if (accessedFields.contains(func.getClass) && + !accessedFields(func.getClass).contains("$outer")) { + logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") + field.set(func, null) + } else { + // Update this closure's parent pointer to point to our enclosing object, + // which could either be a cloned closure or the original user object + field.set(func, parent) + } } - } - logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + } else { + logDebug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}") + + // scalastyle:off classforname + val captClass = Class.forName(lambdaFunc.get.getCapturingClass.replace('/', '.'), + false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + // Fail fast if we detect return statements in closures + getClassReader(captClass) + .accept(new ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0) + logDebug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is now cleaned +++") + } if (checkSerializable) { ensureSerializable(func) @@ -366,20 +424,30 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") -private class ReturnStatementFinder extends ClassVisitor(ASM5) { +private class ReturnStatementFinder(targetMethodName: Option[String] = None) + extends ClassVisitor(ASM6) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { + // $anonfun$ covers Java 8 lambdas if (name.contains("apply") || name.contains("$anonfun$")) { - new MethodVisitor(ASM5) { + // A method with suffix "$adapted" will be generated in cases like + // { _:Int => return; Seq()} but not { _:Int => return; true} + // closure passed is $anonfun$t$1$adapted while actual code resides in 
$anonfun$s$1 + // visitor will see only $anonfun$s$1$adapted, so we remove the suffix, see + // https://github.com/scala/scala-dev/issues/109 + val isTargetMethod = targetMethodName.isEmpty || + name == targetMethodName.get || name == targetMethodName.get.stripSuffix("$adapted") + + new MethodVisitor(ASM6) { override def visitTypeInsn(op: Int, tp: String) { - if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { + if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") && isTargetMethod) { throw new ReturnStatementInClosureException } } } } else { - new MethodVisitor(ASM5) {} + new MethodVisitor(ASM6) {} } } } @@ -403,7 +471,7 @@ private[util] class FieldAccessFinder( findTransitively: Boolean, specificMethod: Option[MethodIdentifier[_]] = None, visitedMethods: Set[MethodIdentifier[_]] = Set.empty) - extends ClassVisitor(ASM5) { + extends ClassVisitor(ASM6) { override def visitMethod( access: Int, @@ -418,7 +486,7 @@ private[util] class FieldAccessFinder( return null } - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { @@ -458,7 +526,7 @@ private[util] class FieldAccessFinder( } } -private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM6) { var myName: String = null // TODO: Recursively find inner closures that we indirectly reference, e.g. @@ -473,7 +541,7 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitMethodInsn( op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index 3ea9139e11027..651ea4996f6cb 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { private val stopped = new AtomicBoolean(false) - private val eventThread = new Thread(name) { + // Exposed for testing. 
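getSerializedLambda relies on the fact that serializable Java 8 / Scala 2.12 lambdas carry a compiler-generated writeReplace method whose result describes the lambda. A minimal sketch of that reflective probe, runnable on Scala 2.12+ outside Spark:

import java.lang.invoke.SerializedLambda

object LambdaProbe {
  // Reflectively call the synthetic writeReplace method that serializable lambdas carry.
  def probe(closure: AnyRef): Option[SerializedLambda] =
    try {
      val writeReplace = closure.getClass.getDeclaredMethod("writeReplace")
      writeReplace.setAccessible(true)
      writeReplace.invoke(closure) match {
        case lambda: SerializedLambda => Some(lambda)
        case _ => None
      }
    } catch {
      case _: NoSuchMethodException => None   // not a lambda, e.g. an ordinary anonymous class
    }

  def main(args: Array[String]): Unit = {
    val f: Int => Int = x => x + 1
    probe(f).foreach { lambda =>
      // Slash-separated internal name of the class the lambda was written in,
      // and the name of the generated implementation method, e.g. $anonfun$main$1.
      println(lambda.getCapturingClass)
      println(lambda.getImplMethodName)
    }
  }
}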
+ private[spark] val eventThread = new Thread(name) { setDaemon(true) override def run(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 40383fe05026b..50c6461373dee 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -407,7 +407,9 @@ private[spark] object JsonProtocol { ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => - ("Kill Reason" -> taskKilled.reason) + val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList) + ("Kill Reason" -> taskKilled.reason) ~ + ("Accumulator Updates" -> accumUpdates) case _ => emptyJson } ("Reason" -> reason) ~ json @@ -917,7 +919,10 @@ private[spark] object JsonProtocol { case `taskKilled` => val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") - TaskKilled(killReason) + val accumUpdates = jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(Seq[AccumulableInfo]()) + TaskKilled(killReason, accumUpdates) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index b25a731401f23..a8f10684d5a2c 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } } + /** + * This can be overridden by subclasses if there is any extra cleanup to do when removing a + * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. + */ + def removeListenerOnError(listener: L): Unit = { + removeListener(listener) + } + + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } try { doPostEvent(listener, event) + if (Thread.interrupted()) { + // We want to throw the InterruptedException right away so we can associate the interrupt + // with this listener, as opposed to waiting for a queue.take() etc. to detect it. + throw new InterruptedException() + } } catch { + case ie: InterruptedException => + logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " + + s"Removing that listener.", ie) + removeListenerOnError(listener) case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { diff --git a/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala new file mode 100644 index 0000000000000..1aa2009fa9b5b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +/** + * SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we catch + * fatal throwable in {@link scala.concurrent.Future}'s body, and re-throw + * SparkFatalException, which wraps the fatal throwable inside. + * Note that SparkFatalException should only be thrown from a {@link scala.concurrent.Future}, + * which is run by using ThreadUtils.awaitResult. ThreadUtils.awaitResult will catch + * it and re-throw the original exception/error. + */ +private[spark] final class SparkFatalException(val throwable: Throwable) extends Exception diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index e0f5af5250e7f..1b34fbde38cd6 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -39,10 +39,15 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) if (!ShutdownHookManager.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(SparkExitCode.OOM) - } else if (exitOnUncaughtException) { - System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + exception match { + case _: OutOfMemoryError => + System.exit(SparkExitCode.OOM) + case e: SparkFatalException if e.throwable.isInstanceOf[OutOfMemoryError] => + // SPARK-24294: This is defensive code, in case that SparkFatalException is + // misused and uncaught. 
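SparkFatalException exists because a fatal error thrown inside a Future body can otherwise be lost (scala/bug#9554); wrapping it in an ordinary exception lets the Future fail normally, and the awaiting side unwraps it. A stripped-down sketch of the same wrap/unwrap handshake, using a stand-in class since the real one is private[spark]:

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal

// Stand-in for the private[spark] SparkFatalException.
final class FatalWrapper(val throwable: Throwable) extends Exception(throwable)

object FatalHandshake {
  // Producer side: catch fatal throwables and smuggle them out as a non-fatal wrapper.
  def guardedFuture[T](body: => T): Future[T] = Future {
    try body catch {
      case t: Throwable if !NonFatal(t) => throw new FatalWrapper(t)
    }
  }

  // Consumer side: unwrap before anyone else sees the wrapper, as ThreadUtils.awaitResult does.
  def await[T](f: Future[T]): T =
    try Await.result(f, Duration.Inf) catch {
      case e: FatalWrapper => throw e.throwable
    }

  def main(args: Array[String]): Unit = {
    val f = guardedFuture[Int] { throw new OutOfMemoryError("simulated") }
    try await(f) catch { case e: OutOfMemoryError => println(s"got the original error: $e") }
  }
}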
+ System.exit(SparkExitCode.OOM) + case _ if exitOnUncaughtException => + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 81aaf79db0c13..f0e5addbe5b56 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,12 +19,15 @@ package org.apache.spark.util import java.util.concurrent._ -import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} -import scala.concurrent.duration.Duration -import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} -import scala.util.control.NonFatal +import scala.collection.TraversableLike +import scala.collection.generic.CanBuildFrom +import scala.language.higherKinds import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future} +import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} +import scala.util.control.NonFatal import org.apache.spark.SparkException @@ -103,6 +106,22 @@ private[spark] object ThreadUtils { executor } + /** + * Wrapper over ScheduledThreadPoolExecutor. + */ + def newDaemonThreadPoolScheduledExecutor(threadNamePrefix: String, numThreads: Int) + : ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(s"$threadNamePrefix-%d") + .build() + val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + /** * Run a piece of code in a new thread and return the result. Exception in the new thread is * thrown in the caller thread with an adjusted stack trace that removes references to this @@ -200,6 +219,8 @@ private[spark] object ThreadUtils { val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] awaitable.result(atMost)(awaitPermission) } catch { + case e: SparkFatalException => + throw e.throwable // TimeoutException is thrown in the current thread, so not need to warp the exception. case NonFatal(t) if !t.isInstanceOf[TimeoutException] => throw new SparkException("Exception thrown in awaitResult: ", t) @@ -227,4 +248,72 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitready + + def shutdown( + executor: ExecutorService, + gracePeriod: Duration = FiniteDuration(30, TimeUnit.SECONDS)): Unit = { + executor.shutdown() + executor.awaitTermination(gracePeriod.toMillis, TimeUnit.MILLISECONDS) + if (!executor.isShutdown) { + executor.shutdownNow() + } + } + + /** + * Transforms input collection by applying the given function to each element in parallel fashion. + * Comparing to the map() method of Scala parallel collections, this method can be interrupted + * at any time. This is useful on canceling of task execution, for example. + * + * @param in - the input collection which should be transformed in parallel. + * @param prefix - the prefix assigned to the underlying thread pool. + * @param maxThreads - maximum number of thread can be created during execution. 
+ * @param f - the lambda function will be applied to each element of `in`. + * @tparam I - the type of elements in the input collection. + * @tparam O - the type of elements in resulted collection. + * @return new collection in which each element was given from the input collection `in` by + * applying the lambda function `f`. + */ + def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]] + (in: Col[I], prefix: String, maxThreads: Int) + (f: I => O) + (implicit + cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map + cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence + ): Col[O] = { + val pool = newForkJoinPool(prefix, maxThreads) + try { + implicit val ec = ExecutionContext.fromExecutor(pool) + + parmap(in)(f) + } finally { + pool.shutdownNow() + } + } + + /** + * Transforms input collection by applying the given function to each element in parallel fashion. + * Comparing to the map() method of Scala parallel collections, this method can be interrupted + * at any time. This is useful on canceling of task execution, for example. + * + * @param in - the input collection which should be transformed in parallel. + * @param f - the lambda function will be applied to each element of `in`. + * @param ec - an execution context for parallel applying of the given function `f`. + * @tparam I - the type of elements in the input collection. + * @tparam O - the type of elements in resulted collection. + * @return new collection in which each element was given from the input collection `in` by + * applying the lambda function `f`. + */ + def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]] + (in: Col[I]) + (f: I => O) + (implicit + cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map + cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence + ec: ExecutionContext + ): Col[O] = { + val futures = in.map(x => Future(f(x))) + val futureSeq = Future.sequence(futures) + + awaitResult(futureSeq, Duration.Inf) + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d2be93226e2a2..e6646bd073c6b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,8 @@ package org.apache.spark.util import java.io._ +import java.lang.{Byte => JByte} +import java.lang.InternalError import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -26,11 +28,12 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files +import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ +import java.util.concurrent.TimeUnit.NANOSECONDS import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream -import javax.net.ssl.HttpsURLConnection import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -44,6 +47,7 @@ import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import com.google.common.hash.HashCodes import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils @@ -79,6 +83,7 @@ private[spark] object Utils extends Logging { 
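The generic parmap above fans work out through a dedicated pool so the caller can interrupt it. A simplified Seq-only version of the same shape, keeping the prefix and maxThreads parameters; the real method is generic over collection types via CanBuildFrom:

import java.util.concurrent.{Executors, ThreadFactory}
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object ParmapSketch {
  // Run f on every element using a private, named, daemon pool, wait for all results,
  // and interrupt any in-flight work if the caller itself is interrupted while waiting.
  def parmapSeq[I, O](in: Seq[I], prefix: String, maxThreads: Int)(f: I => O): Seq[O] = {
    val counter = new AtomicInteger(0)
    val factory = new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = new Thread(r, s"$prefix-${counter.incrementAndGet()}")
        t.setDaemon(true)
        t
      }
    }
    val pool = Executors.newFixedThreadPool(maxThreads, factory)
    try {
      implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
      val futures = in.map(x => Future(f(x)))
      Await.result(Future.sequence(futures), Duration.Inf)
    } finally {
      pool.shutdownNow()  // interrupts outstanding tasks if the await above was interrupted
    }
  }
}

// Example usage: ParmapSketch.parmapSeq(Seq("a.txt", "b.txt"), "probe", 4)(path => path.length)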
val random = new Random() private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler + @volatile private var cachedLocalDir: String = "" /** * Define a default value for driver memory here since this value is referenced across the code @@ -96,7 +101,7 @@ private[spark] object Utils extends Logging { */ val DEFAULT_MAX_TO_STRING_FIELDS = 25 - private def maxNumToStringFields = { + private[spark] def maxNumToStringFields = { if (SparkEnv.get != null) { SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) } else { @@ -431,7 +436,7 @@ private[spark] object Utils extends Logging { new URI("file:///" + rawFileName).getPath.substring(1) } - /** + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -458,7 +463,15 @@ private[spark] object Utils extends Logging { if (useCache && fetchCacheEnabled) { val cachedFileName = s"${url.hashCode}${timestamp}_cache" val lockFileName = s"${url.hashCode}${timestamp}_lock" - val localDir = new File(getLocalDir(conf)) + // Set the cachedLocalDir for the first time and re-use it later + if (cachedLocalDir.isEmpty) { + this.synchronized { + if (cachedLocalDir.isEmpty) { + cachedLocalDir = getLocalDir(conf) + } + } + } + val localDir = new File(cachedLocalDir) val lockFile = new File(localDir, lockFileName) val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel() // Only one executor entry. @@ -504,6 +517,14 @@ private[spark] object Utils extends Logging { targetFile } + /** Records the duration of running `body`. */ + def timeTakenMs[T](body: => T): (T, Long) = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + (result, math.max(NANOSECONDS.toMillis(endTime - startTime), 0)) + } + /** * Download `in` to `tempFile`, then move it to `destFile`. * @@ -755,13 +776,17 @@ private[spark] object Utils extends Logging { * - Otherwise, this will return java.io.tmpdir. * * Some of these configuration options might be lists of multiple paths, but this method will - * always return a single directory. + * always return a single directory. The return directory is chosen randomly from the array + * of directories it gets from getOrCreateLocalRootDirs. */ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val localRootDirs = getOrCreateLocalRootDirs(conf) + if (localRootDirs.isEmpty) { val configuredLocalDirs = getConfiguredLocalDirs(conf) throw new IOException( s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } else { + localRootDirs(scala.util.Random.nextInt(localRootDirs.length)) } } @@ -803,20 +828,20 @@ private[spark] object Utils extends Logging { // to what Yarn on this system said was available. Note this assumes that Yarn has // created the directories already, and that they are secured so that only the // user has access to them. 
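The cachedLocalDir change above is plain double-checked locking on a @volatile field, so the directory is picked once per JVM and later calls skip the lock. The idiom in isolation, with computeLocalDir standing in for getLocalDir(conf):

object LocalDirCache {
  @volatile private var cachedDir: String = ""

  // First call computes and publishes the value; later calls return it without locking.
  def localDir(computeLocalDir: => String): String = {
    if (cachedDir.isEmpty) {
      this.synchronized {
        if (cachedDir.isEmpty) {          // re-check: another thread may have won the race
          cachedDir = computeLocalDir
        }
      }
    }
    cachedDir
  }
}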
- getYarnLocalDirs(conf).split(",") + randomizeInPlace(getYarnLocalDirs(conf).split(",")) } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { conf.getenv("SPARK_LOCAL_DIRS").split(",") - } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { + } else if (conf.getenv("MESOS_SANDBOX") != null && !shuffleServiceEnabled) { // Mesos already creates a directory per Mesos task. Spark should use that directory // instead so all temporary files are automatically cleaned up when the Mesos task ends. // Note that we don't want this if the shuffle service is enabled because we want to // continue to serve shuffle files after the executors that wrote them have already exited. - Array(conf.getenv("MESOS_DIRECTORY")) + Array(conf.getenv("MESOS_SANDBOX")) } else { - if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { - logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + + if (conf.getenv("MESOS_SANDBOX") != null && shuffleServiceEnabled) { + logInfo("MESOS_SANDBOX available but not using provided Mesos sandbox because " + "spark.shuffle.service.enabled is enabled.") } // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user @@ -1384,13 +1409,14 @@ private[spark] object Utils extends Logging { } } + // A regular expression to match classes of the internal Spark API's + // that we want to skip when finding the call site of a method. + private val SPARK_CORE_CLASS_REGEX = + """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r + private val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r + /** Default filtering function for finding call sites using `getCallSite`. */ private def sparkInternalExclusionFunction(className: String): Boolean = { - // A regular expression to match classes of the internal Spark API's - // that we want to skip when finding the call site of a method. - val SPARK_CORE_CLASS_REGEX = - """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r - val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r val SCALA_CORE_CLASS_PREFIX = "scala" val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined || SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined @@ -1818,7 +1844,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ def getFormattedClassName(obj: AnyRef): String = { - obj.getClass.getSimpleName.replace("$", "") + getSimpleName(obj.getClass).replace("$", "") } /** @@ -2689,6 +2715,86 @@ private[spark] object Utils extends Logging { s"k8s://$resolvedURL" } + + /** + * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id + * and {{APP_ID}} occurrences with the App Id. + */ + def substituteAppNExecIds(opt: String, appId: String, execId: String): String = { + opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId) + } + + /** + * Replaces all the {{APP_ID}} occurrences with the App Id. 
+ */ + def substituteAppId(opt: String, appId: String): String = { + opt.replace("{{APP_ID}}", appId) + } + + def createSecret(conf: SparkConf): String = { + val bits = conf.get(AUTH_SECRET_BIT_LENGTH) + val rnd = new SecureRandom() + val secretBytes = new Array[Byte](bits / JByte.SIZE) + rnd.nextBytes(secretBytes) + HashCodes.fromBytes(secretBytes).toString() + } + + /** + * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. + * This method mimicks scalatest's getSimpleNameOfAnObjectsClass. + */ + def getSimpleName(cls: Class[_]): String = { + try { + return cls.getSimpleName + } catch { + case err: InternalError => return stripDollars(stripPackages(cls.getName)) + } + } + + /** + * Remove the packages from full qualified class name + */ + private def stripPackages(fullyQualifiedName: String): String = { + fullyQualifiedName.split("\\.").takeRight(1)(0) + } + + /** + * Remove trailing dollar signs from qualified class name, + * and return the trailing part after the last dollar sign in the middle + */ + private def stripDollars(s: String): String = { + val lastDollarIndex = s.lastIndexOf('$') + if (lastDollarIndex < s.length - 1) { + // The last char is not a dollar sign + if (lastDollarIndex == -1 || !s.contains("$iw")) { + // The name does not have dollar sign or is not an intepreter + // generated class, so we should return the full string + s + } else { + // The class name is intepreter generated, + // return the part after the last dollar sign + // This is the same behavior as getClass.getSimpleName + s.substring(lastDollarIndex + 1) + } + } + else { + // The last char is a dollar sign + // Find last non-dollar char + val lastNonDollarChar = s.reverse.find(_ != '$') + lastNonDollarChar match { + case None => s + case Some(c) => + val lastNonDollarIndex = s.lastIndexOf(c) + if (lastNonDollarIndex == -1) { + s + } else { + // Strip the trailing dollar signs + // Invoke stripDollars again to get the simple name + stripDollars(s.substring(0, lastNonDollarIndex + 1)) + } + } + } + } } private[util] object CallerContext extends Logging { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 5c6dd45ec58e3..19ff109b673e1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -80,7 +80,10 @@ class ExternalAppendOnlyMap[K, V, C]( this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) } - @volatile private var currentMap = new SizeTrackingAppendOnlyMap[K, C] + /** + * Exposed for testing + */ + @volatile private[collection] var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager @@ -267,7 +270,7 @@ class ExternalAppendOnlyMap[K, V, C]( */ def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = { readingIterator = new SpillableIterator(inMemoryIterator) - readingIterator + readingIterator.toCompletionIterator } /** @@ -280,8 +283,7 @@ class ExternalAppendOnlyMap[K, V, C]( "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") } if (spilledMaps.isEmpty) { - CompletionIterator[(K, C), Iterator[(K, C)]]( - destructiveIterator(currentMap.iterator), 
freeCurrentMap()) + destructiveIterator(currentMap.iterator) } else { new ExternalIterator() } @@ -305,8 +307,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator( - currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap()) + private val sortedMap = destructiveIterator( + currentMap.destructiveSortedIterator(keyComparator)) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -565,16 +567,14 @@ class ExternalAppendOnlyMap[K, V, C]( } } - context.addTaskCompletionListener(context => cleanup()) + context.addTaskCompletionListener[Unit](context => cleanup()) } - private[this] class SpillableIterator(var upstream: Iterator[(K, C)]) + private class SpillableIterator(var upstream: Iterator[(K, C)]) extends Iterator[(K, C)] { private val SPILL_LOCK = new Object() - private var nextUpstream: Iterator[(K, C)] = null - private var cur: (K, C) = readNext() private var hasSpilled: Boolean = false @@ -585,17 +585,24 @@ class ExternalAppendOnlyMap[K, V, C]( } else { logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") - nextUpstream = spillMemoryIteratorToDisk(upstream) + val nextUpstream = spillMemoryIteratorToDisk(upstream) + assert(!upstream.hasNext) hasSpilled = true + upstream = nextUpstream true } } + private def destroy(): Unit = { + freeCurrentMap() + upstream = Iterator.empty + } + + def toCompletionIterator: CompletionIterator[(K, C), SpillableIterator] = { + CompletionIterator[(K, C), SpillableIterator](this, this.destroy) + } + def readNext(): (K, C) = SPILL_LOCK.synchronized { - if (nextUpstream != null) { - upstream = nextUpstream - nextUpstream = null - } if (upstream.hasNext) { upstream.next() } else { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d2..b159200d79222 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -368,8 +368,8 @@ private[spark] class ExternalSorter[K, V, C]( val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) type Iter = BufferedIterator[Product2[K, C]] val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { - // Use the reverse of comparator.compare because PriorityQueue dequeues the max - override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + // Use the reverse order because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1) }) heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true new Iterator[Product2[K, C]] { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 60f6f537c1d54..8883e17bf3164 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -28,9 +28,9 @@ import org.apache.spark.annotation.Private * 
removed. * * The underlying implementation uses Scala compiler's specialization to generate optimized - * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet - * while incurring much less memory overhead. This can serve as building blocks for higher level - * data structures such as an optimized HashMap. + * storage for four primitive types (Long, Int, Double, and Float). It is much faster than Java's + * standard HashSet while incurring much less memory overhead. This can serve as building blocks + * for higher level data structures such as an optimized HashMap. * * This OpenHashSet is designed to serve as building blocks for higher level data structures * such as an optimized hash map. Compared with standard hash set implementations, this class @@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ @Private -class OpenHashSet[@specialized(Long, Int) T: ClassTag]( +class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( initialCapacity: Int, loadFactor: Double) extends Serializable { @@ -77,6 +77,10 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( (new LongHasher).asInstanceOf[Hasher[T]] } else if (mt == ClassTag.Int) { (new IntHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassTag.Double) { + (new DoubleHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassTag.Float) { + (new FloatHasher).asInstanceOf[Hasher[T]] } else { new Hasher[T] } @@ -293,7 +297,7 @@ object OpenHashSet { * A set of specialized hash function implementation to avoid boxing hash code computation * in the specialized implementation of OpenHashSet. */ - sealed class Hasher[@specialized(Long, Int) T] extends Serializable { + sealed class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable { def hash(o: T): Int = o.hashCode() } @@ -305,6 +309,17 @@ object OpenHashSet { override def hash(o: Int): Int = o } + class DoubleHasher extends Hasher[Double] { + override def hash(o: Double): Int = { + val bits = java.lang.Double.doubleToLongBits(o) + (bits ^ (bits >>> 32)).toInt + } + } + + class FloatHasher extends Hasher[Float] { + override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o) + } + private def grow1(newSize: Int) {} private def move1(oldPos: Int, newPos: Int) { } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7367af7888bd8..39f050f6ca5ad 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -17,17 +17,21 @@ package org.apache.spark.util.io -import java.io.InputStream +import java.io.{File, FileInputStream, InputStream} import java.nio.ByteBuffer -import java.nio.channels.WritableByteChannel +import java.nio.channels.{FileChannel, WritableByteChannel} +import java.nio.file.StandardOpenOption + +import scala.collection.mutable.ListBuffer import com.google.common.primitives.UnsignedBytes -import io.netty.buffer.{ByteBuf, Unpooled} import org.apache.spark.SparkEnv import org.apache.spark.internal.config +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.storage.StorageUtils +import org.apache.spark.util.Utils /** * Read-only byte buffer which is physically stored as multiple chunks rather 
than a single @@ -63,19 +67,28 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining() > 0) { + val originalLimit = bytes.limit() + while (bytes.hasRemaining) { + // If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct + // ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread. + // Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may + // cause significant native memory leak, if a large direct ByteBuffer is allocated and + // cached, as it's never released until thread exits. Here we write the `bytes` with + // fixed-size slices to limit the size of the cached direct ByteBuffer. + // Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details. val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) bytes.limit(bytes.position() + ioSize) channel.write(bytes) + bytes.limit(originalLimit) } } } /** - * Wrap this buffer to view it as a Netty ByteBuf. + * Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB. */ - def toNetty: ByteBuf = { - Unpooled.wrappedBuffer(chunks.length, getChunks(): _*) + def toNetty: ChunkedByteBufferFileRegion = { + new ChunkedByteBufferFileRegion(this, bufferWriteChunkSize) } /** @@ -157,6 +170,38 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } +object ChunkedByteBuffer { + // TODO eliminate this method if we switch BlockManager to getting InputStreams + def fromManagedBuffer(data: ManagedBuffer, maxChunkSize: Int): ChunkedByteBuffer = { + data match { + case f: FileSegmentManagedBuffer => + map(f.getFile, maxChunkSize, f.getOffset, f.getLength) + case other => + new ChunkedByteBuffer(other.nioByteBuffer()) + } + } + + def map(file: File, maxChunkSize: Int): ChunkedByteBuffer = { + map(file, maxChunkSize, 0, file.length()) + } + + def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = { + Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel => + var remaining = length + var pos = offset + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, maxChunkSize) + val chunk = channel.map(FileChannel.MapMode.READ_ONLY, pos, chunkSize) + pos += chunkSize + remaining -= chunkSize + chunks += chunk + } + new ChunkedByteBuffer(chunks.toArray) + } + } +} + /** * Reads data from a ChunkedByteBuffer. * diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala new file mode 100644 index 0000000000000..9622d0ac05368 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util.io + +import java.nio.channels.WritableByteChannel + +import io.netty.channel.FileRegion +import io.netty.util.AbstractReferenceCounted + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.AbstractFileRegion + + +/** + * This exposes a ChunkedByteBuffer as a netty FileRegion, just to allow sending > 2gb in one netty + * message. This is because netty cannot send a ByteBuf > 2g, but it can send a large FileRegion, + * even though the data is not backed by a file. + */ +private[io] class ChunkedByteBufferFileRegion( + private val chunkedByteBuffer: ChunkedByteBuffer, + private val ioChunkSize: Int) extends AbstractFileRegion { + + private var _transferred: Long = 0 + // this duplicates the original chunks, so we're free to modify the position, limit, etc. + private val chunks = chunkedByteBuffer.getChunks() + private val size = chunks.foldLeft(0L) { _ + _.remaining() } + + protected def deallocate: Unit = {} + + override def count(): Long = size + + // this is the "start position" of the overall Data in the backing file, not our current position + override def position(): Long = 0 + + override def transferred(): Long = _transferred + + private var currentChunkIdx = 0 + + def transferTo(target: WritableByteChannel, position: Long): Long = { + assert(position == _transferred) + if (position == size) return 0L + var keepGoing = true + var written = 0L + var currentChunk = chunks(currentChunkIdx) + while (keepGoing) { + while (currentChunk.hasRemaining && keepGoing) { + val ioSize = Math.min(currentChunk.remaining(), ioChunkSize) + val originalLimit = currentChunk.limit() + currentChunk.limit(currentChunk.position() + ioSize) + val thisWriteSize = target.write(currentChunk) + currentChunk.limit(originalLimit) + written += thisWriteSize + if (thisWriteSize < ioSize) { + // the channel did not accept our entire write. We do *not* keep trying -- netty wants + // us to just stop, and report how much we've written. 
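Both writeFully and transferTo above cap each channel write by temporarily lowering the buffer's limit, so the per-thread temporary direct buffer that NIO caches never grows past the chunk size; they also restore the limit and tolerate short writes. The trick on its own:

import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel

object SlicedWrites {
  // Write an on-heap buffer to a channel in bounded slices by adjusting the limit,
  // restoring it after every write so the loop can continue after a short write.
  def writeInSlices(buf: ByteBuffer, channel: WritableByteChannel, sliceSize: Int): Unit = {
    val originalLimit = buf.limit()
    while (buf.hasRemaining) {
      val ioSize = math.min(buf.remaining(), sliceSize)
      buf.limit(buf.position() + ioSize)
      channel.write(buf)
      buf.limit(originalLimit)
    }
  }
}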
+ keepGoing = false + } + } + if (keepGoing) { + // advance to the next chunk (if there are any more) + currentChunkIdx += 1 + if (currentChunkIdx == chunks.size) { + keepGoing = false + } else { + currentChunk = chunks(currentChunkIdx) + } + } + } + _transferred += written + written + } +} diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index db91329c94cb6..0bbaea6b834b8 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -17,6 +17,10 @@ package org.apache.spark.memory; +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.unsafe.memory.MemoryBlock; + import java.io.IOException; public class TestMemoryConsumer extends MemoryConsumer { @@ -43,6 +47,12 @@ void free(long size) { used -= size; taskMemoryManager.releaseExecutionMemory(size, this); } + + @VisibleForTesting + public void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0d5c5ea7903e9..faa70f23b0ac6 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -233,6 +233,7 @@ public void writeEmptyIterator() throws Exception { writer.write(Iterators.emptyIterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(0, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); @@ -252,6 +253,7 @@ public void writeWithoutSpilling() throws Exception { writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index c145532328514..85ffdca436e14 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -129,7 +129,6 @@ public int compare( final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; - Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); final String str = diff --git a/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java b/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java index 7e9cc70d8651f..0f489fb219010 100644 --- a/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.*; /** - * Java apps can uses both Java-friendly JavaSparkContext and Scala SparkContext. + * Java apps can use both Java-friendly JavaSparkContext and Scala SparkContext. 
*/ public class JavaSparkContextSuite implements Serializable { diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 3990ee1ec326d..5d0ffd92647bc 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -209,10 +209,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex System.gc() assert(ref.get.isEmpty) - // Getting a garbage collected accum should throw error - intercept[IllegalStateException] { - AccumulatorContext.get(accId) - } + // Getting a garbage collected accum should return None. + assert(AccumulatorContext.get(accId).isEmpty) // Getting a normal accumulator. Note: this has to be separate because referencing an // accumulator above in an `assert` would keep it from being garbage collected. diff --git a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala new file mode 100644 index 0000000000000..d49ab4aa7df12 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.rdd.{PartitionPruningRDD, RDD} +import org.apache.spark.scheduler.BarrierJobAllocationFailed._ +import org.apache.spark.scheduler.DAGScheduler +import org.apache.spark.util.ThreadUtils + +/** + * This test suite covers all the cases that shall fail fast on a job submitted that contains one + * or more barrier stages.
+ */ +class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext { + + private def createSparkContext(conf: Option[SparkConf] = None): SparkContext = { + new SparkContext(conf.getOrElse( + new SparkConf() + .setMaster("local[4]") + .setAppName("test"))) + } + + private def testSubmitJob( + sc: SparkContext, + rdd: RDD[Int], + partitions: Option[Seq[Int]] = None, + message: String): Unit = { + val futureAction = sc.submitJob( + rdd, + (iter: Iterator[Int]) => iter.toArray, + partitions.getOrElse(0 until rdd.partitions.length), + { case (_, _) => return }: (Int, Array[Int]) => Unit, + { return } + ) + + val error = intercept[SparkException] { + ThreadUtils.awaitResult(futureAction, 5 seconds) + }.getCause.getMessage + assert(error.contains(message)) + } + + test("submit a barrier ResultStage that contains PartitionPruningRDD") { + sc = createSparkContext() + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier ShuffleMapStage that contains PartitionPruningRDD") { + sc = createSparkContext() + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage that doesn't contain PartitionPruningRDD") { + sc = createSparkContext() + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .repartition(2) + .barrier() + .mapPartitions(iter => iter) + // Should be able to submit job and run successfully. + val result = rdd.collect().sorted + assert(result === Seq(6, 7, 8, 9, 10)) + } + + test("submit a barrier stage with partial partitions") { + sc = createSparkContext() + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, Some(Seq(1, 3)), + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with union()") { + sc = createSparkContext() + val rdd1 = sc.parallelize(1 to 10, 2) + .barrier() + .mapPartitions(iter => iter) + val rdd2 = sc.parallelize(1 to 20, 2) + val rdd3 = rdd1 + .union(rdd2) + .map(x => x * 2) + // Fail the job on submit because the barrier RDD (rdd1) may not be assigned Task 0. + testSubmitJob(sc, rdd3, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with coalesce()") { + sc = createSparkContext() + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + .coalesce(1) + // Fail the job on submit because the barrier RDD requires running on 4 tasks, but the stage + // only launches 1 task.
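// (A barrier stage has to launch all of its tasks together, so any chain where the number of
// launched tasks can differ from the number of barrier partitions, such as coalesce(), union(),
// PartitionPruningRDD, or submitting only a subset of partitions, is rejected at job-submission
// time with the matching ERROR_MESSAGE_* constant from BarrierJobAllocationFailed.)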
+ testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage that contains an RDD that depends on multiple barrier RDDs") { + sc = createSparkContext() + val rdd1 = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + val rdd2 = sc.parallelize(11 to 20, 4) + .barrier() + .mapPartitions(iter => iter) + val rdd3 = rdd1 + .zip(rdd2) + .map(x => x._1 + x._2) + testSubmitJob(sc, rdd3, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with zip()") { + sc = createSparkContext() + val rdd1 = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + val rdd2 = sc.parallelize(11 to 20, 4) + val rdd3 = rdd1 + .zip(rdd2) + .map(x => x._1 + x._2) + // Should be able to submit job and run successfully. + val result = rdd3.collect().sorted + assert(result === Seq(12, 14, 16, 18, 20, 22, 24, 26, 28, 30)) + } + + test("submit a barrier ResultStage with dynamic resource allocation enabled") { + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + } + + test("submit a barrier ShuffleMapStage with dynamic resource allocation enabled") { + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + } + + test("submit a barrier ResultStage that requires more slots than current total under local " + + "mode") { + val conf = new SparkConf() + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ShuffleMapStage that requires more slots than current total under " + + "local mode") { + val conf = new SparkConf() + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. 
+ .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ResultStage that requires more slots than current total under " + + "local-cluster mode") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ShuffleMapStage that requires more slots than current total under " + + "local-cluster mode") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } +} diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 28ea0c6f0bdba..629a323042ff2 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.Matchers import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} +import org.apache.spark.internal.config import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -154,6 +155,21 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.parallelize(1 to 10).count() } + private def testCaching(testName: String, conf: SparkConf, storageLevel: StorageLevel): Unit = { + test(testName) { + testCaching(conf, storageLevel) + } + if (storageLevel.replication > 1) { + // also try with block replication as a stream + val uploadStreamConf = new SparkConf() + uploadStreamConf.setAll(conf.getAll) + uploadStreamConf.set(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM, 1L) + test(s"$testName (with replication as stream)") { + testCaching(uploadStreamConf, storageLevel) + } + } + } + private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) TestUtils.waitUntilExecutorsUp(sc, 2, 30000) @@ -169,7 +185,10 @@ class DistributedSuite extends 
SparkFunSuite with Matchers with LocalSparkContex val blockManager = SparkEnv.get.blockManager val blockTransfer = blockManager.blockTransferService val serializerManager = SparkEnv.get.serializerManager - blockManager.master.getLocations(blockId).foreach { cmId => + val locations = blockManager.master.getLocations(blockId) + assert(locations.size === storageLevel.replication, + s"; got ${locations.size} replicas instead of ${storageLevel.replication}") + locations.foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString, null) val deserialized = serializerManager.dataDeserializeStream(blockId, @@ -189,8 +208,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 ).foreach { case (testName, storageLevel) => - encryptionTest(testName) { conf => - testCaching(conf, storageLevel) + encryptionTestHelper(testName) { case (name, conf) => + testCaching(name, conf, storageLevel) } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 9807d1269e3d4..659ebb60fef86 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -145,6 +145,39 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = { + val conf = new SparkConf() + .setMaster("myDummyLocalExternalClusterManager") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .set("spark.dynamicAllocation.maxExecutors", "15") + .set("spark.dynamicAllocation.minExecutors", "3") + .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString) + .set("spark.executor.cores", cores.toString) + val sc = new SparkContext(conf) + contexts += sc + var manager = sc.executorAllocationManager.get + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 20))) + for (i <- 0 to 5) { + addExecutors(manager) + } + assert(numExecutorsTarget(manager) === expected) + sc.stop() + } + + test("executionAllocationRatio is correctly handled") { + testAllocationRatio(1, 0.5, 10) + testAllocationRatio(1, 1.0/3.0, 7) + testAllocationRatio(2, 1.0/3.0, 4) + testAllocationRatio(1, 0.385, 8) + + // max/min executors capping + testAllocationRatio(1, 1.0, 15) // should be 20 but capped by max + testAllocationRatio(4, 1.0/3.0, 3) // should be 2 but elevated by min + } + + test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get @@ -1343,6 +1376,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def defaultParallelism(): Int = sb.defaultParallelism() + override def maxNumConcurrentTasks(): Int = sb.maxNumConcurrentTasks() + override def killExecutorsOnHost(host: String): Boolean = { false } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 88916488c0def..b705556e54b14 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.util.concurrent.{ExecutorService, TimeUnit} -import scala.collection.Map import scala.collection.mutable import scala.concurrent.Future import scala.concurrent.duration._ @@ -73,6 +72,7 @@ class HeartbeatReceiverSuite sc = spy(new SparkContext(conf)) scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.nodeBlacklist).thenReturn(Predef.Set[String]()) when(scheduler.sc).thenReturn(sc) heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) @@ -241,7 +241,7 @@ class HeartbeatReceiverSuite } === Some(true)) } - private def getTrackedExecutors: Map[String, Long] = { + private def getTrackedExecutors: collection.Map[String, Long] = { // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, // so exclude it from the map. See SPARK-10800. heartbeatReceiver.invokePrivate(_executorLastSeen()). @@ -272,7 +272,7 @@ private class FakeSchedulerBackend( protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty[String])) + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty)) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 1dd89bcbe36bc..05aaaa11451b4 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -29,7 +29,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self override def beforeAll() { super.beforeAll() - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) + InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) } override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 50b8ea754d8d9..e79739692fe13 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L))) + Array(1000L, 10000L), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L))) + Array(10000L, 1000L), 10)) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array(compressedSize1000, compressedSize10000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 
1000), - Array(compressedSize10000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000), 10)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array(compressedSize1000, compressedSize1000, compressedSize1000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000, compressedSize1000), 10)) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,9 +145,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) + BlockManagerId("a", "hostA", 1000), Array(1000L), 10)) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L))) + Array(3L), 1)) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. 
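The MapStatus factory exercised in these hunks now takes a third argument next to the per-block sizes, surfaced as numberOfOutput on the resulting status; the tests only show the call shape, so the sketch below is illustrative (it assumes code living under org.apache.spark, since MapStatus is package-private, and does not pin down the exact meaning of the count):

  import org.apache.spark.scheduler.MapStatus
  import org.apache.spark.storage.BlockManagerId

  // a map task that produced two shuffle blocks of 1000 and 10000 bytes, with an output count of 10
  val status = MapStatus(BlockManagerId("exec-1", "hostA", 1000), Array(1000L, 10000L), 10)
  assert(status.location.host == "hostA")
  assert(status.numberOfOutput == 10)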
@@ -260,7 +260,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -298,4 +298,33 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("zero-sized blocks should be excluded when getMapSizesByExecutorId") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + + val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(size0, size1000, size0, size10000), 1)) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(size10000, size0, size1000, size0), 1)) + assert(tracker.containsShuffle(10)) + assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + Seq( + (BlockManagerId("a", "hostA", 1000), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + (BlockManagerId("b", "hostB", 1000), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + ) + ) + + tracker.unregisterShuffle(10) + tracker.stop() + rpcEnv.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 8eabc2b3cb958..5dbfc5c10a6f8 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark import java.io.File +import java.util.UUID import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.alias.{CredentialProvider, CredentialProviderFactory} import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.SparkConfWithEnv @@ -40,6 +43,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { .toSet val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -49,7 +53,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) conf.set("spark.ssl.protocol", "TLSv1.2") - val opts = SSLOptions.parse(conf, "spark.ssl") + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl") assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -70,6 +74,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -80,8 +85,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { 
"TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -103,6 +108,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.ui.port", "4242") @@ -117,8 +123,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.port === Some(4242)) @@ -139,14 +145,71 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConfWithEnv(Map( "ENV1" -> "val1", "ENV2" -> "val2")) + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", "${env:ENV1}") conf.set("spark.ssl.trustStore", "${env:ENV2}") - val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) assert(opts.keyStore === Some(new File("val1"))) assert(opts.trustStore === Some(new File("val2"))) } + test("get password from Hadoop credential provider") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + val hadoopConf = new Configuration() + val tmpPath = s"localjceks://file${sys.props("java.io.tmpdir")}/test-" + + s"${UUID.randomUUID().toString}.jceks" + val provider = createCredentialProvider(tmpPath, hadoopConf) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + storePassword(provider, "spark.ssl.keyStorePassword", "password") + storePassword(provider, "spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + storePassword(provider, "spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) 
+ assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + private def createCredentialProvider(tmpPath: String, conf: Configuration): CredentialProvider = { + conf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, tmpPath) + + val provider = CredentialProviderFactory.getProviders(conf).get(0) + if (provider == null) { + throw new IllegalStateException(s"Fail to get credential provider with path $tmpPath") + } + + provider + } + + private def storePassword( + provider: CredentialProvider, + passwordKey: String, + password: String): Unit = { + provider.createCredentialEntry(passwordKey, password.toCharArray) + provider.flush() + } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index ced5a06516f75..456f97b535ef6 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -208,7 +208,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect() results should have length (1) - // substracted rdd return results as Tuple2 + // subtracted rdd return results as Tuple2 results(0) should be ((3, 33)) } @@ -391,6 +391,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(mapOutput2.isDefined) assert(mapOutput1.get.location === mapOutput2.get.location) assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput) // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index bff808eb540ac..0d06b02e74e34 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -339,6 +339,38 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } + val defaultIllegalValue = "SomeIllegalValue" + val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map( + "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)), + "getTimeAsSeconds with default" -> (_.getTimeAsSeconds(_, defaultIllegalValue)), + "getTimeAsMs" -> (_.getTimeAsMs(_)), + "getTimeAsMs with default" -> (_.getTimeAsMs(_, defaultIllegalValue)), + "getSizeAsBytes" -> (_.getSizeAsBytes(_)), + "getSizeAsBytes with default string" -> (_.getSizeAsBytes(_, defaultIllegalValue)), + "getSizeAsBytes with default long" -> (_.getSizeAsBytes(_, 0L)), + "getSizeAsKb" -> (_.getSizeAsKb(_)), + "getSizeAsKb with default" -> (_.getSizeAsKb(_, defaultIllegalValue)), + "getSizeAsMb" -> (_.getSizeAsMb(_)), + "getSizeAsMb with default" -> (_.getSizeAsMb(_, defaultIllegalValue)), + "getSizeAsGb" -> (_.getSizeAsGb(_)), + "getSizeAsGb with default" -> (_.getSizeAsGb(_, defaultIllegalValue)), + "getInt" -> (_.getInt(_, 0)), + "getLong" -> (_.getLong(_, 0L)), + "getDouble" -> (_.getDouble(_, 0.0)), + "getBoolean" -> (_.getBoolean(_, false)) + ) + + illegalValueTests.foreach { case (name, 
getValue) => + test(s"SPARK-24337: $name throws a useful error message with key name") { + val key = "SomeKey" + val conf = new SparkConf() + conf.set(key, "SomeInvalidValue") + val thrown = intercept[IllegalArgumentException] { + getValue(conf, key) + } + assert(thrown.getMessage.contains(key)) + } + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index ce9f2be1c02dd..e1666a35271d3 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -627,6 +627,51 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(exc.getCause() != null) stream.close() } + + test("support barrier execution mode under local mode") { + val conf = new SparkConf().setAppName("test").setMaster("local[2]") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (context.getTaskInfos().length != 2) { + throw new SparkException("Expected taskInfos length is 2, actual length is " + + s"${context.getTaskInfos().length}.") + } + context.barrier() + it + } + rdd2.collect() + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + + test("support barrier execution mode under local-cluster mode") { + val conf = new SparkConf() + .setMaster("local-cluster[3, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (context.getTaskInfos().length != 2) { + throw new SparkException("Expected taskInfos length is 2, actual length is " + + s"${context.getTaskInfos().length}.") + } + context.barrier() + it + } + rdd2.collect() + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } } object SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala deleted file mode 100644 index ab24a76e20a30..0000000000000 --- a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.spark.deploy - -import java.security.PrivilegedExceptionAction - -import scala.util.Random - -import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.fs.permission.{FsAction, FsPermission} -import org.apache.hadoop.security.UserGroupInformation -import org.scalatest.Matchers - -import org.apache.spark.SparkFunSuite - -class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { - test("check file permission") { - import FsAction._ - val testUser = s"user-${Random.nextInt(100)}" - val testGroups = Array(s"group-${Random.nextInt(100)}") - val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) - - testUgi.doAs(new PrivilegedExceptionAction[Void] { - override def run(): Void = { - val sparkHadoopUtil = new SparkHadoopUtil - - // If file is owned by user and user has access permission - var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by user but user has no access permission - status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - val otherUser = s"test-${Random.nextInt(100)}" - val otherGroup = s"test-${Random.nextInt(100)}" - - // If file is owned by user's group and user's group has access permission - status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by user's group but user's group has no access permission - status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - // If file is owned by other user and this user has access permission - status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by other user but this user has no access permission - status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - null - } - }) - } - - private def fileStatus( - owner: String, - group: String, - userAction: FsAction, - groupAction: FsAction, - otherAction: FsAction): FileStatus = { - new FileStatus(0L, - false, - 0, - 0L, - 0L, - 0L, - new FsPermission(userAction, groupAction, otherAction), - owner, - group, - null) - } -} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7451e07b25a1f..f829fecc30840 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -180,6 +180,26 @@ class SparkSubmitSuite appArgs.toString should include ("thequeue") } + test("SPARK-24241: do not fail fast if executor num is 0 when dynamic allocation is enabled") { + val clArgs1 = Seq( + "--name", "myApp", + "--class", "Foo", + 
"--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=true", + "thejar.jar") + new SparkSubmitArguments(clArgs1) + + val clArgs2 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=false", + "thejar.jar") + + val e = intercept[SparkException](new SparkSubmitArguments(clArgs2)) + assert(e.getMessage.contains("Number of executors must be a positive number")) + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", @@ -751,9 +771,13 @@ class SparkSubmitSuite PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val pyFile1 = File.createTempFile("file1", ".py", tmpDir) + val pyFile2 = File.createTempFile("file2", ".py", tmpDir) val writer4 = new PrintWriter(f4) - val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py" + val remotePyFiles = s"s3a://${pyFile1.getAbsolutePath},s3a://${pyFile2.getAbsolutePath}" writer4.println("spark.submit.pyFiles " + remotePyFiles) writer4.close() val clArgs4 = Seq( @@ -763,7 +787,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4, conf = Some(hadoopConf)) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -971,20 +995,24 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = true) } test("force download from blacklisted schemes") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http")) + } + + test("force download for all the schemes") { + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*")) } private def testRemoteResources( enableHttpFs: Boolean, - blacklistHttpFs: Boolean): Unit = { + blacklistSchemes: Seq[String] = Nil): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) if (enableHttpFs) { @@ -1001,8 +1029,8 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" - val forceDownloadArgs = if (blacklistHttpFs) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + val forceDownloadArgs = if (blacklistSchemes.nonEmpty) { + Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}") } else { Nil } @@ -1020,14 +1048,19 @@ class SparkSubmitSuite val jars = conf.get("spark.yarn.dist.jars").split(",").toSet - // The URI of remote S3 resource should still be remote. 
- assert(jars.contains(tmpS3JarPath)) + def isSchemeBlacklisted(scheme: String) = { + blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme) + } + + if (!isSchemeBlacklisted("s3")) { + assert(jars.contains(tmpS3JarPath)) + } - if (enableHttpFs && !blacklistHttpFs) { + if (enableHttpFs && blacklistSchemes.isEmpty) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) - } else { + } else if (!enableHttpFs || isSchemeBlacklisted("http")) { // If Http FS is not supported by yarn service, or http scheme is configured to be force // downloading, the URI of remote http resource should be changed to a local one. val jarName = new File(tmpHttpJar.toURI).getName @@ -1073,6 +1106,44 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + test("support --py-files/spark.submit.pyFiles in non pyspark application") { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + + val tmpDir = Utils.createTempDir() + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs = new SparkSubmitArguments(args) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) + + conf.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.submit.pyFiles") should (startWith("/")) + + // Verify "spark.submit.pyFiles" + val args1 = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--conf", s"spark.submit.pyFiles=s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs1 = new SparkSubmitArguments(args1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1, conf = Some(hadoopConf)) + + conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf1.get("spark.submit.pyFiles") should (startWith("/")) + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index eb8c203ae7751..a0f09891787e0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -256,4 +256,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mydep") >= 0, "should find dependency") } } + + test("SPARK-10878: test resolution files cleaned after resolving artifact") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + + IvyTestUtils.withRepository(main, None, None) { repo => + val ivySettings = SparkSubmitUtils.buildIvySettings(Some(repo), Some(tempIvyPath)) + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + ivySettings, + isTest = true) + val r = """.*org.apache.spark-spark-submit-parent-.*""".r + assert(!ivySettings.getDefaultCache.listFiles.map(_.getName) + .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned") + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 77b239489d489..b4eba755eccbf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -29,9 +29,11 @@ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.security.AccessControlException import org.json4s.jackson.JsonMethods._ -import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, spy, verify} +import org.mockito.ArgumentMatcher +import org.mockito.Matchers.{any, argThat} +import org.mockito.Mockito.{doThrow, mock, spy, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -818,6 +820,42 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("SPARK-24948: blacklist files we don't have read permission on") { + val clock = new ManualClock(1533132471) + val provider = new FsHistoryProvider(createTestConf(), clock) + val accessDenied = newLogFile("accessDenied", None, inProgress = false) + writeFile(accessDenied, true, None, + SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None)) + val accessGranted = newLogFile("accessGranted", None, inProgress = false) + writeFile(accessGranted, true, None, + SparkListenerApplicationStart("accessGranted", Some("accessGranted"), 1L, "test", None), + SparkListenerApplicationEnd(5L)) + val mockedFs = spy(provider.fs) + doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open( + argThat(new ArgumentMatcher[Path]() { + override def matches(path: Any): Boolean = { + path.asInstanceOf[Path].getName.toLowerCase == "accessdenied" + } + })) + val mockedProvider = spy(provider) + when(mockedProvider.fs).thenReturn(mockedFs) + updateAndCheck(mockedProvider) { list => + list.size should be(1) + } + writeFile(accessDenied, true, None, + SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None), + SparkListenerApplicationEnd(5L)) + // Doing 2 times in order to check the blacklist filter too + updateAndCheck(mockedProvider) { list => + list.size should be(1) + } + val accessDeniedPath = new Path(accessDenied.getPath) + assert(mockedProvider.isBlacklisted(accessDeniedPath)) + clock.advance(24 * 60 * 60 * 1000 + 1) // add a bit more than 1d + mockedProvider.cleanLogs() + assert(!mockedProvider.isBlacklisted(accessDeniedPath)) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. 
Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87f12f303cd5e..11b29121739a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -36,6 +36,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito._ import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} @@ -281,6 +282,29 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } + test("automatically retrieve uiRoot from request through Knox") { + assert(sys.props.get("spark.ui.proxyBase").isEmpty, + "spark.ui.proxyBase is defined but it should not for this UT") + assert(sys.env.get("APPLICATION_WEB_PROXY_BASE").isEmpty, + "APPLICATION_WEB_PROXY_BASE is defined but it should not for this UT") + val page = new HistoryPage(server) + val requestThroughKnox = mock[HttpServletRequest] + val knoxBaseUrl = "/gateway/default/sparkhistoryui" + when(requestThroughKnox.getHeader("X-Forwarded-Context")).thenReturn(knoxBaseUrl) + val responseThroughKnox = page.render(requestThroughKnox) + + val urlsThroughKnox = responseThroughKnox \\ "@href" map (_.toString) + val siteRelativeLinksThroughKnox = urlsThroughKnox filter (_.startsWith("/")) + all (siteRelativeLinksThroughKnox) should startWith (knoxBaseUrl) + + val directRequest = mock[HttpServletRequest] + val directResponse = page.render(directRequest) + + val directUrls = directResponse \\ "@href" map (_.toString) + val directSiteRelativeLinks = directUrls filter (_.startsWith("/")) + all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) + } + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) @@ -296,6 +320,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("/version api endpoint") { + val response = getUrl("version") + assert(response.contains(SPARK_VERSION)) + } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index ce212a7513310..e3fe2b696aa1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,10 +17,19 @@ package org.apache.spark.deploy.worker +import java.util.concurrent.atomic.AtomicBoolean +import java.util.function.Supplier + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{SecurityManager, SparkConf, 
SparkFunSuite} -import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.deploy.{Command, ExecutorState, ExternalShuffleService} import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} import org.apache.spark.deploy.master.DriverState import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -29,6 +38,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { import org.apache.spark.deploy.DeployTestUtils._ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleService: ExternalShuffleService = _ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -36,15 +47,21 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private var _worker: Worker = _ - private def makeWorker(conf: SparkConf): Worker = { + private def makeWorker( + conf: SparkConf, + shuffleServiceSupplier: Supplier[ExternalShuffleService] = null): Worker = { assert(_worker === null, "Some Worker's RpcEnv is leaked in tests") val securityMgr = new SecurityManager(conf) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr) _worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, securityMgr) + "Worker", "/tmp", conf, securityMgr, shuffleServiceSupplier) _worker } + before { + MockitoAnnotations.initMocks(this) + } + after { if (_worker != null) { _worker.rpcEnv.shutdown() @@ -194,4 +211,36 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { assert(worker.finishedDrivers.size === expectedValue) } } + + test("cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=true") { + testCleanupFilesWithConfig(true) + } + + test("don't cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=false") { + testCleanupFilesWithConfig(false) + } + + private def testCleanupFilesWithConfig(value: Boolean) = { + val conf = new SparkConf().set("spark.storage.cleanupFilesAfterExecutorExit", value.toString) + + val cleanupCalled = new AtomicBoolean(false) + when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + cleanupCalled.set(true) + } + }) + val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] { + override def get: ExternalShuffleService = shuffleService + } + val worker = makeWorker(conf, externalShuffleServiceSupplier) + // initialize workers + for (i <- 0 until 10) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(cleanupCalled.get() == value) + } } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 1a7bebe2c53cd..77a7668d3a1d1 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -275,6 +275,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug executorId = "", name = "", index = 0, + partitionId = 0, addedFiles = Map[String, Long](), addedJars = Map[String, Long](), properties = new Properties, diff --git 
a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala new file mode 100644 index 0000000000000..817dc082b7d38 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.collection.immutable.IndexedSeq + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Tests the correctness of + * [[org.apache.spark.input.WholeTextFileInputFormat WholeTextFileInputFormat]]. A temporary + * directory containing files is created as fake input which is deleted in the end. + */ +class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { + private var sc: SparkContext = _ + + override def beforeAll() { + super.beforeAll() + val conf = new SparkConf() + sc = new SparkContext("local", "test", conf) + } + + override def afterAll() { + try { + sc.stop() + } finally { + super.afterAll() + } + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], + compress: Boolean) = { + val path = s"${inputDir.toString}/$fileName" + val out = new DataOutputStream(new FileOutputStream(path)) + out.write(contents, 0, contents.length) + out.close() + } + + test("for small files minimum split size per node and per rack should be less than or equal to " + + "maximum split size.") { + var dir : File = null; + try { + dir = Utils.createTempDir() + logInfo(s"Local disk address is ${dir.toString}.") + + // Set the minsize per node and rack to be larger than the size of the input file. + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.node", 123456) + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.rack", 123456) + + WholeTextFileInputFormatSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, false) + } + // ensure spark job runs successfully without exceptions from the CombineFileInputFormat + assert(sc.wholeTextFiles(dir.toString).count == 3) + } finally { + Utils.deleteRecursively(dir) + } + } +} + +/** + * Files to be tested are defined here. 
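+ * Three files (part-00000..part-00002) of 10, 100 and 1000 bytes are produced by repeating the test sentence until each target length is reached.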
+ */ +object WholeTextFileInputFormatSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala new file mode 100644 index 0000000000000..a6b0654204f34 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io + +import java.nio.ByteBuffer +import java.nio.channels.WritableByteChannel + +import scala.util.Random + +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.util.io.ChunkedByteBuffer + +class ChunkedByteBufferFileRegionSuite extends SparkFunSuite with MockitoSugar + with BeforeAndAfterEach { + + override protected def beforeEach(): Unit = { + super.beforeEach() + val conf = new SparkConf() + val env = mock[SparkEnv] + SparkEnv.set(env) + when(env.conf).thenReturn(conf) + } + + override protected def afterEach(): Unit = { + SparkEnv.set(null) + } + + private def generateChunkedByteBuffer(nChunks: Int, perChunk: Int): ChunkedByteBuffer = { + val bytes = (0 until nChunks).map { chunkIdx => + val bb = ByteBuffer.allocate(perChunk) + (0 until perChunk).foreach { idx => + bb.put((chunkIdx * perChunk + idx).toByte) + } + bb.position(0) + bb + }.toArray + new ChunkedByteBuffer(bytes) + } + + test("transferTo can stop and resume correctly") { + SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 9L) + val cbb = generateChunkedByteBuffer(4, 10) + val fileRegion = cbb.toNetty + + val targetChannel = new LimitedWritableByteChannel(40) + + var pos = 0L + // write the fileregion to the channel, but with the transfer limited at various spots along + // the way. 
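+ // (The buffer built above holds 4 chunks of 10 bytes each, 40 bytes in total, with config.BUFFER_WRITE_CHUNK_SIZE lowered to 9 bytes.)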
+ + // limit to within the first chunk + targetChannel.acceptNBytes = 5 + pos = fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 5) + + // a little bit further within the first chunk + targetChannel.acceptNBytes = 2 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 7) + + // past the first chunk, into the 2nd + targetChannel.acceptNBytes = 6 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 13) + + // right to the end of the 2nd chunk + targetChannel.acceptNBytes = 7 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 20) + + // rest of 2nd chunk, all of 3rd, some of 4th + targetChannel.acceptNBytes = 15 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 35) + + // now till the end + targetChannel.acceptNBytes = 5 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 40) + + // calling again at the end should be OK + targetChannel.acceptNBytes = 20 + fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 40) + } + + test(s"transfer to with random limits") { + val rng = new Random() + val seed = System.currentTimeMillis() + logInfo(s"seed = $seed") + rng.setSeed(seed) + val chunkSize = 1e4.toInt + SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, rng.nextInt(chunkSize).toLong) + + val cbb = generateChunkedByteBuffer(50, chunkSize) + val fileRegion = cbb.toNetty + val transferLimit = 1e5.toInt + val targetChannel = new LimitedWritableByteChannel(transferLimit) + while (targetChannel.pos < cbb.size) { + val nextTransferSize = rng.nextInt(transferLimit) + targetChannel.acceptNBytes = nextTransferSize + fileRegion.transferTo(targetChannel, targetChannel.pos) + } + assert(0 === fileRegion.transferTo(targetChannel, targetChannel.pos)) + } + + /** + * This mocks a channel which only accepts a limited number of bytes at a time. It also verifies + * the written data matches our expectations as the data is received. 
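+ * `acceptNBytes` caps how much the pending write() calls may consume before the test tops it up again; `pos` counts the total bytes received so far.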
+ */ + private class LimitedWritableByteChannel(maxWriteSize: Int) extends WritableByteChannel { + val bytes = new Array[Byte](maxWriteSize) + var acceptNBytes = 0 + var pos = 0 + + override def write(src: ByteBuffer): Int = { + val length = math.min(acceptNBytes, src.remaining()) + src.get(bytes, 0, length) + acceptNBytes -= length + // verify we got the right data + (0 until length).foreach { idx => + assert(bytes(idx) === (pos + idx).toByte, s"; wrong data at ${pos + idx}") + } + pos += length + length + } + + override def isOpen: Boolean = true + + override def close(): Unit = {} + } + +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 3b798e36b0499..ff117b1c21cb1 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.util.io.ChunkedByteBuffer -class ChunkedByteBufferSuite extends SparkFunSuite { +class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { test("no chunks") { val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) @@ -33,7 +34,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(emptyChunkedByteBuffer.getChunks().isEmpty) assert(emptyChunkedByteBuffer.toArray === Array.empty) assert(emptyChunkedByteBuffer.toByteBuffer.capacity() === 0) - assert(emptyChunkedByteBuffer.toNetty.capacity() === 0) + assert(emptyChunkedByteBuffer.toNetty.count() === 0) emptyChunkedByteBuffer.toInputStream(dispose = false).close() emptyChunkedByteBuffer.toInputStream(dispose = true).close() } @@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(chunkedByteBuffer.getChunks().head.position() === 0) } + test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") { + try { + sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024))) + val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt) + chunkedByteBuffer.writeFully(byteArrayWritableChannel) + assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size) + } finally { + sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE) + } + } + test("toArray()") { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala new file mode 100644 index 0000000000000..d57ea4d5501e3 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + +class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext { + + test("create an RDDBarrier") { + val rdd = sc.parallelize(1 to 10, 4) + assert(rdd.isBarrier() === false) + + val rdd2 = rdd.barrier().mapPartitions(iter => iter) + assert(rdd2.isBarrier() === true) + } + + test("create an RDDBarrier in the middle of a chain of RDDs") { + val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2) + val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1)) + assert(rdd2.isBarrier() === true) + } + + test("RDDBarrier with shuffle") { + val rdd = sc.parallelize(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions(iter => iter).repartition(2) + assert(rdd2.isBarrier() === false) + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 191c61250ce21..b143a468a1baf 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -154,6 +154,16 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("SPARK-23778: empty RDD in union should not produce a UnionRDD") { + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val emptyRDD = sc.emptyRDD[(Int, Boolean)] + val unionRDD = sc.union(emptyRDD, rddWithPartitioner) + assert(unionRDD.isInstanceOf[PartitionerAwareUnionRDD[_]]) + val unionAllEmptyRDD = sc.union(emptyRDD, emptyRDD) + assert(unionAllEmptyRDD.isInstanceOf[UnionRDD[_]]) + assert(unionAllEmptyRDD.collect().isEmpty) + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) @@ -433,7 +443,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") } - test("coalesced RDDs with partial locality") { + test("coalesced RDDs with partial locality") { // Make an RDD that has some locality preferences and some without. 
This can happen // with UnionRDD val data = sc.makeRDD((1 to 9).map(i => { @@ -836,6 +846,28 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8))) } + test("cartesian on empty RDD") { + val a = sc.emptyRDD[Int] + val b = sc.parallelize(1 to 3) + val cartesian_result = Array.empty[(Int, Int)] + assert(a.cartesian(a).collect().toList === cartesian_result) + assert(a.cartesian(b).collect().toList === cartesian_result) + assert(b.cartesian(a).collect().toList === cartesian_result) + } + + test("cartesian on non-empty RDDs") { + val a = sc.parallelize(1 to 3) + val b = sc.parallelize(2 to 4) + val c = sc.parallelize(1 to 1) + val a_cartesian_b = + Array((1, 2), (1, 3), (1, 4), (2, 2), (2, 3), (2, 4), (3, 2), (3, 3), (3, 4)) + val a_cartesian_c = Array((1, 1), (2, 1), (3, 1)) + val c_cartesian_a = Array((1, 1), (1, 2), (1, 3)) + assert(a.cartesian[Int](b).collect().toList.sorted === a_cartesian_b) + assert(a.cartesian[Int](c).collect().toList.sorted === a_cartesian_c) + assert(c.cartesian[Int](a).collect().toList.sorted === c_cartesian_a) + } + test("intersection") { val all = sc.parallelize(1 to 10) val evens = sc.parallelize(2 to 10 by 2) @@ -1047,7 +1079,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) { private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty override def compute(p: Partition, c: TaskContext): Iterator[T] = Iterator.empty - override def getPartitions: Array[Partition] = Array.empty + override def getPartitions: Array[Partition] = Array(new Partition { + override def index: Int = 0 + }) override def getDependencies: Seq[Dependency[_]] = mutableDependencies def addDependency(dep: Dependency[_]) { mutableDependencies += dep diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala new file mode 100644 index 0000000000000..36dd620a56853 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.util.Random + +import org.apache.spark._ + +class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { + + test("global sync by barrier() call") { + val conf = new SparkConf() + // Init local cluster here so each barrier task runs in a separated process, thus `barrier()` + // call is actually useful. 
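+ // local-cluster[4, 1, 1024] below means 4 workers with 1 core each, enough slots to run all 4 barrier tasks of the 4-partition RDD concurrently.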
+ .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + Seq(System.currentTimeMillis()).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish global sync within a short time slot. + assert(times.max - times.min <= 1000) + } + + test("support multiple barrier() call within a single task") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time between two global syncs. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. + val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + + test("throw exception on barrier() call timeout") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Task 3 shall sleep 2000ms to ensure barrier() call timeout + if (context.taskAttemptId == 3) { + Thread.sleep(2000) + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } + + test("throw exception if barrier() call doesn't happen on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + if (context.taskAttemptId != 0) { + context.barrier() + } + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } + + test("throw exception if the number of barrier() calls are not the same on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + try { + if (context.taskAttemptId == 0) { + // Due to some non-obvious reason, the code can trigger an Exception and skip 
the + // following statements within the try ... catch block, including the first barrier() + // call. + throw new SparkException("test") + } + context.barrier() + } catch { + case e: Exception => // Do nothing + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 04cccc67e328e..80c9c6f0422a8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,10 +17,18 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicBoolean + +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually + import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.rdd.RDD import org.apache.spark.util.{RpcUtils, SerializableBuffer} -class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext + with Eventually { test("serialized task larger than max RPC message size") { val conf = new SparkConf @@ -38,4 +46,83 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo assert(smaller.size === 4) } + test("compute max number of concurrent tasks can be launched") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = new SparkContext(conf) + eventually(timeout(10.seconds)) { + // Ensure all executors have been launched. + assert(sc.getExecutorIds().length == 4) + } + assert(sc.maxNumConcurrentTasks() == 12) + } + + test("compute max number of concurrent tasks can be launched when spark.task.cpus > 1") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = new SparkContext(conf) + eventually(timeout(10.seconds)) { + // Ensure all executors have been launched. + assert(sc.getExecutorIds().length == 4) + } + // Each executor can only launch one task since `spark.task.cpus` is 2. + assert(sc.maxNumConcurrentTasks() == 4) + } + + test("compute max number of concurrent tasks can be launched when some executors are busy") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = new SparkContext(conf) + val rdd = sc.parallelize(1 to 10, 4).mapPartitions { iter => + Thread.sleep(5000) + iter + } + var taskStarted = new AtomicBoolean(false) + var taskEnded = new AtomicBoolean(false) + val listener = new SparkListener() { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + taskStarted.set(true) + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + taskEnded.set(true) + } + } + + try { + sc.addSparkListener(listener) + eventually(timeout(10.seconds)) { + // Ensure all executors have been launched. + assert(sc.getExecutorIds().length == 4) + } + + // Submit a job to trigger some tasks on active executors. 
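+ // testSubmitJob (defined below) starts the job asynchronously through sc.submitJob, without waiting for it to finish.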
+ testSubmitJob(sc, rdd) + + eventually(timeout(10.seconds)) { + // Ensure some tasks have started and no task finished, so some executors must be busy. + assert(taskStarted.get() == true) + assert(taskEnded.get() == false) + // Assert we count in slots on both busy and free executors. + assert(sc.maxNumConcurrentTasks() == 4) + } + } finally { + sc.removeSparkListener(listener) + } + } + + private def testSubmitJob(sc: SparkContext, rdd: RDD[Int]): Unit = { + sc.submitJob( + rdd, + (iter: Iterator[Int]) => iter.toArray, + 0 until rdd.partitions.length, + { case (_, _) => return }: (Int, Array[Int]) => Unit, + { return } + ) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8b6ec37625eec..56ba23c38af7f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -131,6 +131,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } override def killTaskAttempt( taskId: Long, interruptThread: Boolean, reason: String): Boolean = false + override def killAllTaskAttempts( + stageId: Int, interruptThread: Boolean, reason: String): Unit = {} override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -213,7 +215,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } private def init(testConf: SparkConf): Unit = { - sc = new SparkContext("local", "DAGSchedulerSuite", testConf) + sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf) sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() @@ -421,17 +423,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // make sure our test setup is correct @@ -629,6 +631,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi taskId: Long, interruptThread: Boolean, reason: String): Boolean = { throw new UnsupportedOperationException } + override def killAllTaskAttempts( + stageId: Int, interruptThread: Boolean, reason: String): Unit = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( @@ 
-1055,6 +1061,91 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(sparkListener.failedStages.size == 1) } + test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null)) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) + + scheduler.resubmitFailedStages() + // Complete the map stage. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2) + + // Complete the result stage. + completeNextResultStageWithSuccess(1, 1) + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assertDataStructuresEmpty() + } + + test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by TaskKilled") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(1))) + + // The second map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + TaskKilled("test"), + null)) + assert(sparkListener.failedStages === Seq(0)) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) + + scheduler.resubmitFailedStages() + // Complete the map stage. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2) + + // Complete the result stage. + completeNextResultStageWithSuccess(1, 0) + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assertDataStructuresEmpty() + } + + test("Fail the job if a barrier ResultTask failed") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + .barrier() + .mapPartitions(iter => iter) + submit(reduceRdd, Array(0, 1)) + + // Complete the map stage. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostA", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // The first ResultTask fails + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + TaskKilled("test"), + null)) + + // Assert the stage has been cancelled. 
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(failure.getMessage.startsWith("Job aborted due to stage failure: Could not recover " + + "from a failed barrier ResultStage.")) + } + /** * This tests the case where another FetchFailed comes in while the map stage is getting * re-run. @@ -1852,7 +1943,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } - test("accumulators are updated on exception failures") { + test("accumulators are updated on exception failures and task killed") { val acc1 = AccumulatorSuite.createLongAccum("ingenieur") val acc2 = AccumulatorSuite.createLongAccum("boulanger") val acc3 = AccumulatorSuite.createLongAccum("agriculteur") @@ -1868,15 +1959,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accUpdate3 = new LongAccumulator accUpdate3.metadata = acc3.metadata accUpdate3.setValue(18) - val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3) - val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo) + + val accumUpdates1 = Seq(accUpdate1, accUpdate2) + val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo) val exceptionFailure = new ExceptionFailure( new SparkException("fondue?"), - accumInfo).copy(accums = accumUpdates) + accumInfo1).copy(accums = accumUpdates1) submit(new MyRDD(sc, 1, Nil), Array(0)) runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(AccumulatorContext.get(acc1.id).get.value === 15L) assert(AccumulatorContext.get(acc2.id).get.value === 13L) + + val accumUpdates2 = Seq(accUpdate3) + val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo) + + val taskKilled = new TaskKilled( "test", accumInfo2, accums = accumUpdates2) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result")) + assert(AccumulatorContext.get(acc3.id).get.value === 18L) } @@ -2313,9 +2413,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // Runs a job that encounters a single fetch failure but succeeds on the second attempt def runJobWithTemporaryFetchFailure: Unit = { - object FailThisAttempt { - val _fail = new AtomicBoolean(true) - } val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() val shuffleHandle = rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle @@ -2395,7 +2492,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(makeCompletionEvent( taskSets(1).tasks(1), Success, makeMapStatus("hostA", 2))) - // Both tasks in rddB should be resubmitted, because none of them has succeeded truely. + // Both tasks in rddB should be resubmitted, because none of them has succeeded truly. // Complete the task(stageId=1, stageAttemptId=1, partitionId=0) successfully. // Task(stageId=1, stageAttemptId=1, partitionId=1) of this new active stage attempt // is still running. 
@@ -2451,6 +2548,85 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } + test("Barrier task failures from the same stage attempt don't trigger multiple stage retries") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + + val mapStageId = 0 + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == mapStageId) + } + + // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 1) + + // The first map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), + TaskKilled("test"), + null)) + assert(sparkListener.failedStages === Seq(0)) + + // The second map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + TaskKilled("test"), + null)) + + // Trigger resubmission of the failed map stage. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + // Another attempt for the map stage should have been submitted, resulting in 2 total attempts. + assert(countSubmittedMapStageAttempts() === 2) + } + + test("Barrier task failures from a previous stage attempt don't trigger stage retry") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + + val mapStageId = 0 + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == mapStageId) + } + + // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 1) + + // The first map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), + TaskKilled("test"), + null)) + assert(sparkListener.failedStages === Seq(0)) + + // Trigger resubmission of the failed map stage. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + // Another attempt for the map stage should have been submitted, resulting in 2 total attempts. + assert(countSubmittedMapStageAttempts() === 2) + + // The second map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + TaskKilled("test"), + null)) + + // The second map task failure doesn't trigger stage retry. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 2) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. 
@@ -2497,6 +2673,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accumUpdates = reason match { case Success => task.metrics.accumulators() case ef: ExceptionFailure => ef.accums + case tk: TaskKilled => tk.accums case _ => Seq.empty } CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) @@ -2505,8 +2682,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) } + +object FailThisAttempt { + val _fail = new AtomicBoolean(true) +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index a4e4ea7cd2894..b4705914b999b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -69,6 +69,7 @@ private class DummySchedulerBackend extends SchedulerBackend { def stop() {} def reviveOffers() {} def defaultParallelism(): Int = 1 + def maxNumConcurrentTasks(): Int = 0 } private class DummyTaskScheduler extends TaskScheduler { @@ -81,6 +82,8 @@ private class DummyTaskScheduler extends TaskScheduler { override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} override def killTaskAttempt( taskId: Long, interruptThread: Boolean, reason: String): Boolean = false + override def killAllTaskAttempts( + stageId: Int, interruptThread: Boolean, reason: String): Unit = {} override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 109d4a0a870b8..b29d32f7b35c5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -27,8 +27,10 @@ class FakeTask( partitionId: Int, prefLocs: Seq[TaskLocation] = Nil, serializedTaskMetrics: Array[Byte] = - SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) - extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics) { + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), + isBarrier: Boolean = false) + extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics, + isBarrier = isBarrier) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs @@ -74,4 +76,22 @@ object FakeTask { } new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } + + def createBarrierTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createBarrierTaskSet(numTasks, stageId = 0, stageAttempId = 0, prefLocs: _*) + } + + def createBarrierTaskSet( + numTasks: Int, + stageId: Int, + stageAttempId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong 
number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true) + } + new TaskSet(tasks, stageId, stageAttempId, priority = 0, null) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 2155a0f2b6c21..555e48bd28aa0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -60,7 +60,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -74,7 +74,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,7 +86,7 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -108,7 +108,7 @@ class MapStatusSuite extends SparkFunSuite { val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -164,7 +164,7 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. 
val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 1) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) @@ -188,4 +188,32 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } + + test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") { + val conf = new SparkConf() + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + val sizes = Array.fill[Long](500)(150L) + // Test default value + val status = MapStatus(null, sizes, 1) + assert(status.isInstanceOf[CompressedMapStatus]) + // Test Non-positive values + for (s <- -1 to 0) { + assertThrows[IllegalArgumentException] { + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes, 1) + } + } + // Test positive values + Seq(1, 100, 499, 500, 501).foreach { s => + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes, 1) + if(sizes.length > s) { + assert(status.isInstanceOf[HighlyCompressedMapStatus]) + } else { + assert(status.isInstanceOf[CompressedMapStatus]) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 03b1903902491..158c9eb75f2b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -153,7 +154,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Job should not complete if all commits are denied") { // Create a mock OutputCommitCoordinator that denies all attempts to commit doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( - Matchers.any(), Matchers.any(), Matchers.any()) + Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any()) val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) def resultHandler(x: Int, y: Unit): Unit = {} val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, @@ -169,45 +170,106 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 2 val authorizedCommitter: Int = 3 val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage, maxPartitionId = 2) - assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter)) // The non-authorized committer fails - outputCommitCoordinator.taskCompleted( - stage, 
partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer - assert( - outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 2)) // There can only be one authorized committer - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) - } - - test("Duplicate calls to canCommit from the authorized committer gets idempotent responses.") { - val rdd = sc.parallelize(Seq(1), 1) - sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, - 0 until rdd.partitions.size) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 3)) } test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 1 val failedAttempt: Int = 0 outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) - outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = failedAttempt, reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) - assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) - assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt + 1)) + } + + test("SPARK-24589: Differentiate tasks from different stage attempts") { + var stage = 1 + val taskAttempt = 1 + val partition = 1 + + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(!outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Fail the task in the first attempt, the task in the second attempt should succeed. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit, + // then fail the 1st attempt and make sure the 4th one can commit again. 
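+ // ("Attempt" here refers to the stage attempt number; the task attempt number stays fixed at taskAttempt = 1.)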
+ stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 2, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt)) + } + + test("SPARK-24589: Make sure stage state is cleaned up") { + // Normal application without stage failures. + sc.parallelize(1 to 100, 100) + .map { i => (i % 10, i) } + .reduceByKey(_ + _) + .collect() + + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + + // Force failures in a few tasks so that a stage is retried. Collect the ID of the failing + // stage so that we can check the state of the output committer. + val retriedStage = sc.parallelize(1 to 100, 10) + .map { i => (i % 10, i) } + .reduceByKey { case (_, _) => + val ctx = TaskContext.get() + if (ctx.stageAttemptNumber() == 0) { + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1, + new Exception("Failure for test.")) + } else { + ctx.stageId() + } + } + .collect() + .map { case (k, v) => v } + .toSet + + assert(retriedStage.size === 1) + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + verify(sc.env.outputCommitCoordinator, times(2)) + .stageStart(Matchers.eq(retriedStage.head), Matchers.any()) + verify(sc.env.outputCommitCoordinator).stageEnd(Matchers.eq(retriedStage.head)) } } @@ -243,16 +305,6 @@ private case class OutputCommitFunctions(tempDirPath: String) { if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) } - // Receiver should be idempotent for AskPermissionToCommitOutput - def callCanCommitMultipleTimes(iter: Iterator[Int]): Unit = { - val ctx = TaskContext.get() - val canCommit1 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - val canCommit2 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - assert(canCommit1 && canCommit2) - } - private def runCommitWithProvidedCommitter( ctx: TaskContext, iter: Iterator[Int], diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 75ea409e16b4b..cea7f173c8f2f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -385,6 +385,8 @@ private[spark] abstract class MockBackend( }.toIndexedSeq } + override def maxNumConcurrentTasks(): Int = 0 + /** * This is called by the scheduler whenever it has tasks it would like to schedule, when a tasks * completes (which will be in a result-getter thread), and by the reviveOffers thread for delay diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index da6ecb82c7e42..6ffd1e84f7adb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{Externalizable, 
IOException, ObjectInput, ObjectOutput} import java.util.concurrent.Semaphore import scala.collection.JavaConverters._ @@ -294,10 +295,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) - // just to make sure some of the tasks take a noticeable amount of time + // just to make sure some of the tasks and their deserialization take a noticeable + // amount of time + val slowDeserializable = new SlowDeserializable val w = { i: Int => if (i == 0) { Thread.sleep(100) + slowDeserializable.use() } i } @@ -485,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) } + Seq(true, false).foreach { throwInterruptedException => + val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted" + test(s"interrupt within listener is handled correctly: $suffix") { + val conf = new SparkConf(false) + .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5) + val bus = new LiveListenerBus(conf) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val interruptingListener1 = new InterruptingListener(throwInterruptedException) + val interruptingListener2 = new InterruptingListener(throwInterruptedException) + bus.addToSharedQueue(counter1) + bus.addToSharedQueue(interruptingListener1) + bus.addToStatusQueue(counter2) + bus.addToEventLogQueue(interruptingListener2) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 2) + + bus.start(mockSparkContext, mockMetricsSystem) + + // after we post one event, both interrupting listeners should get removed, and the + // event log queue should be removed + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 0) + assert(counter1.count === 1) + assert(counter2.count === 1) + + // posting more events should be fine, they'll just get processed from the OK queue. + (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(counter1.count === 6) + assert(counter2.count === 6) + + // Make sure stopping works -- this requires putting a poison pill in all active queues, which + // would fail if our interrupted queue was still active, as its queue would be full. + bus.stop() + } + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -543,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } + /** + * A simple listener that interrupts on job end. 
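+ * Depending on `throwInterruptedException` it either throws an InterruptedException or just sets the current thread's interrupt flag.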
+ */ + private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (throwInterruptedException) { + throw new InterruptedException("got interrupted") + } else { + Thread.currentThread().interrupt() + } + } + } } // These classes can't be declared inside of the SparkListenerSuite class because we don't want @@ -583,3 +641,12 @@ private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends Spar case _ => } } + +private class SlowDeserializable extends Externalizable { + + override def writeExternal(out: ObjectOutput): Unit = { } + + override def readExternal(in: ObjectInput): Unit = Thread.sleep(1) + + def use(): Unit = { } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 97487ce1d2ca8..ba62eec0522db 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -62,6 +62,7 @@ class TaskDescriptionSuite extends SparkFunSuite { executorId = "testExecutor", name = "task for test", index = 19, + partitionId = 1, originalFiles, originalJars, originalProperties, @@ -77,6 +78,7 @@ class TaskDescriptionSuite extends SparkFunSuite { assert(decodedTaskDescription.executorId === originalTaskDescription.executorId) assert(decodedTaskDescription.name === originalTaskDescription.name) assert(decodedTaskDescription.index === originalTaskDescription.index) + assert(decodedTaskDescription.partitionId === originalTaskDescription.partitionId) assert(decodedTaskDescription.addedFiles.equals(originalFiles)) assert(decodedTaskDescription.addedJars.equals(originalJars)) assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 6003899bb7bef..7a457a0a72d90 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -36,6 +36,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def stop() {} def reviveOffers() {} def defaultParallelism(): Int = 1 + def maxNumConcurrentTasks(): Int = 0 } class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach @@ -62,7 +63,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } override def afterEach(): Unit = { - super.afterEach() if (taskScheduler != null) { taskScheduler.stop() taskScheduler = null @@ -71,6 +71,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B dagScheduler.stop() dagScheduler = null } + super.afterEach() } def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { @@ -917,4 +918,222 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.initialize(new FakeSchedulerBackend) } } + + test("Completions in zombie tasksets update status of non-zombie taskset") { + val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val valueSer = SparkEnv.get.serializer.newInstance() + + def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { + val indexInTsm = tsm.partitionToIndex(partition) + val matchingTaskInfo = 
tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) + } + + // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, + // two times, so we have three active task sets for one stage. (For this to really happen, + // you'd need the previous stage to also get restarted, and then succeed, in between each + // attempt, but that happens outside what we're mocking here.) + val zombieAttempts = (0 until 2).map { stageAttempt => + val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) + taskScheduler.submitTasks(attempt) + val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get + val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 10) + // fail attempt + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + // the attempt is a zombie, but the tasks are still running (this could be true even if + // we actively killed those tasks, as killing is best-effort) + assert(tsm.isZombie) + assert(tsm.runningTasks === 9) + tsm + } + + // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for + // the stage, but this time with insufficient resources so not all tasks are active. + + val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) + taskScheduler.submitTasks(finalAttempt) + val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get + val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => + finalAttempt.tasks(task.index).partitionId + }.toSet + assert(finalTsm.runningTasks === 5) + assert(!finalTsm.isZombie) + + // We simulate late completions from our zombie tasksets, corresponding to all the pending + // partitions in our final attempt. This means we're only waiting on the tasks we've already + // launched. + val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) + finalAttemptPendingPartitions.foreach { partition => + completeTaskSuccessfully(zombieAttempts(0), partition) + } + + // If there is another resource offer, we shouldn't run anything. Though our final attempt + // used to have pending tasks, now those tasks have been completed by zombie attempts. The + // remaining tasks to compute are already active in the non-zombie attempt. + assert( + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) + + val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted + + // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be + // marked as zombie. + // for each of the remaining tasks, find the tasksets with an active copy of the task, and + // finish the task. + remainingTasks.foreach { partition => + val tsm = if (partition == 0) { + // we failed this task on both zombie attempts, this one is only present in the latest + // taskset + finalTsm + } else { + // should be active in every taskset. We choose a zombie taskset just to make sure that + // we transition the active taskset correctly even if the final completion comes + // from a zombie. 
+ zombieAttempts(partition % 2) + } + completeTaskSuccessfully(tsm, partition) + } + + assert(finalTsm.isZombie) + + // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + + // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything + // else succeeds, to make sure we get the right updates to the blacklist in all cases. + (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => + val stageAttempt = tsm.taskSet.stageAttemptId + tsm.runningTasksSet.foreach { index => + if (stageAttempt == 1) { + tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) + } else { + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) + } + } + + // we update the blacklist for the stage attempts with all successful tasks. Even though + // some tasksets had failures, we still consider them all successful from a blacklisting + // perspective, as the failures weren't from a problem w/ the tasks themselves. + verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + } + } + + test("don't schedule for a barrier taskSet if available slots are less than pending tasks") { + val taskCpus = 2 + val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) + + val numFreeCores = 3 + val workerOffers = IndexedSeq( + new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")), + new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627"))) + val attempt1 = FakeTask.createBarrierTaskSet(3) + + // submit attempt 1, offer some resources, since the available slots are less than pending + // tasks, don't schedule barrier tasks on the resource offer. + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions.length) + } + + test("schedule tasks for a barrier taskSet if all tasks can be launched together") { + val taskCpus = 2 + val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) + + val numFreeCores = 3 + val workerOffers = IndexedSeq( + new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")), + new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627")), + new WorkerOffer("executor2", "host2", numFreeCores, Some("192.168.0.101:49629"))) + val attempt1 = FakeTask.createBarrierTaskSet(3) + + // submit attempt 1, offer some resources, all tasks get launched together + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(3 === taskDescriptions.length) + } + + test("cancelTasks shall kill all the running tasks and fail the stage") { + val taskScheduler = setupScheduler() + + taskScheduler.initialize(new FakeSchedulerBackend { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Since we only submit one stage attempt, the following call is sufficient to mark the + // task as killed. 
+ taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId) + } + }) + + val attempt1 = FakeTask.createTaskSet(10, 0) + taskScheduler.submitTasks(attempt1) + + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1)) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + assert(2 === tsm.runningTasks) + + taskScheduler.cancelTasks(0, false) + assert(0 === tsm.runningTasks) + assert(tsm.isZombie) + assert(taskScheduler.taskSetManagerForAttempt(0, 0).isEmpty) + } + + test("killAllTaskAttempts shall kill all the running tasks and not fail the stage") { + val taskScheduler = setupScheduler() + + taskScheduler.initialize(new FakeSchedulerBackend { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Since we only submit one stage attempt, the following call is sufficient to mark the + // task as killed. + taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId) + } + }) + + val attempt1 = FakeTask.createTaskSet(10, 0) + taskScheduler.submitTasks(attempt1) + + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1)) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + assert(2 === tsm.runningTasks) + + taskScheduler.killAllTaskAttempts(0, false, "test") + assert(0 === tsm.runningTasks) + assert(!tsm.isZombie) + assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined) + } + + test("mark taskset for a barrier stage as zombie in case a task fails") { + val taskScheduler = setupScheduler() + + val attempt = FakeTask.createBarrierTaskSet(3) + taskScheduler.submitTasks(attempt) + + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + val offers = (0 until 3).map{ idx => + WorkerOffer(s"exec-$idx", s"host-$idx", 1, Some(s"192.168.0.101:4962$idx")) + } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 3) + + // Fail a task from the stage attempt. 
+ tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, TaskKilled("test")) + assert(tsm.isZombie) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ca6a7e5db3b17..d264adaef90a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -178,12 +178,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } override def afterEach(): Unit = { - super.afterEach() if (sched != null) { sched.dagScheduler.stop() sched.stop() sched = null } + super.afterEach() } @@ -1365,10 +1365,241 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(taskOption4.get.addedJars === addedJarsMidTaskSet) } + test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.1") + sc.conf.set("spark.speculation", "true") + + sched = new FakeTaskScheduler(sc) + sched.initialize(new FakeSchedulerBackend()) + + val dagScheduler = new FakeDAGScheduler(sc, sched) + sched.setDAGScheduler(dagScheduler) + + val taskSet1 = FakeTask.createTaskSet(10) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task => + task.metrics.internalAccums + } + + sched.submitTasks(taskSet1) + sched.resourceOffers( + (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) + + val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get + + // fail fetch + taskSetManager1.handleFailedTask( + taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + + assert(taskSetManager1.isZombie) + assert(taskSetManager1.runningTasks === 9) + + val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1) + sched.submitTasks(taskSet2) + sched.resourceOffers( + (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) + + // Complete the 2 tasks and leave 8 task in running + for (id <- Set(0, 1)) { + taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get + assert(!taskSetManager2.successfulTaskDurations.isEmpty()) + taskSetManager2.checkSpeculatableTasks(0) + } + + + test("SPARK-24755 Executor loss can cause task to not be resubmitted") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + + sc.conf.set("spark.speculation.quantile", "0.5") + sc.conf.set("spark.speculation", "true") + + var killTaskCalled = false + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Check the only one killTask event in this case, which triggered by + // task 2.1 completed. 
+ assert(taskId === 2) + assert(executorId === "exec3") + assert(interruptThread) + assert(reason === "another attempt succeeded") + killTaskCalled = true + } + }) + + // Keep track of the index of tasks that are resubmitted, + // so that the test can check that task is resubmitted correctly + var resubmittedTasks = new mutable.HashSet[Int] + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += taskInfo.index + case _ => + } + } + } + sched.dagScheduler.stop() + sched.setDAGScheduler(dagScheduler) + + val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host3", "exec3")), + Seq(TaskLocation("host2", "exec2"))) + + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((exec, host) <- Seq( + "exec1" -> "host1", + "exec1" -> "host1", + "exec3" -> "host3", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(exec, host, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === exec) + // Add an extra assert to make sure task 2.0 is running on exec3 + if (task.index == 2) { + assert(task.attemptNumber === 0) + assert(task.executorId === "exec3") + } + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete the 2 tasks and leave 2 task in running + for (id <- Set(0, 1)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. 
+ clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(2, 3)) + + // Offer resource to start the speculative attempt for the running task 2.0 + val taskOption = manager.resourceOffer("exec2", "host2", ANY) + assert(taskOption.isDefined) + val task4 = taskOption.get + assert(task4.index === 2) + assert(task4.taskId === 4) + assert(task4.executorId === "exec2") + assert(task4.attemptNumber === 1) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) + // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called + assert(killTaskCalled) + + assert(resubmittedTasks.isEmpty) + // Host 2 is lost, meaning we lost the map output of task 4 + manager.executorLost("exec2", "host2", SlaveLost()) + // Make sure that task with index 2 is re-submitted + assert(resubmittedTasks.contains(2)) + + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates) } + + test("SPARK-13343 speculative tasks that didn't commit shouldn't be marked as success") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation", "true") + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((k, v) <- List( + "exec1" -> "host1", + "exec1" -> "host1", + "exec2" -> "host2", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(k, v, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === k) + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete 3 of the tasks and leave 1 task running + for (id <- Set(0, 1, 2)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(3)) + + // Offer resource to start the speculative attempt for the running task + val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption5.isDefined) + val task5 = taskOption5.get + assert(task5.index === 3) + assert(task5.taskId === 4) + assert(task5.executorId === "exec1") + assert(task5.attemptNumber === 1) + sched.backend = mock(classOf[SchedulerBackend]) + sched.dagScheduler.stop() + sched.dagScheduler = mock(classOf[DAGScheduler]) + // Complete one attempt for the running task + val result = createTaskResult(3, accumUpdatesByTask(3)) + manager.handleSuccessfulTask(3, result) + // There is a race between the scheduler asking to kill the other task, and that task + // actually finishing.
We simulate what happens if the other task finishes before we kill it. + verify(sched.backend).killTask(4, "exec1", true, "another attempt succeeded") + manager.handleSuccessfulTask(4, result) + + val info3 = manager.taskInfos(3) + val info4 = manager.taskInfos(4) + assert(info3.successful) + assert(info4.killed) + verify(sched.dagScheduler).taskEnded( + manager.tasks(3), + TaskKilled("Finish but did not commit due to another attempt succeeded"), + null, + Seq.empty, + info4) + verify(sched.dagScheduler).taskEnded(manager.tasks(3), Success, result.value(), + result.accumUpdates, info3) + } } diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala index 3f52dc41abf6d..be6b8a6b5b108 100644 --- a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala @@ -28,11 +28,15 @@ trait EncryptionFunSuite { * for the test to modify the provided SparkConf. */ final protected def encryptionTest(name: String)(fn: SparkConf => Unit) { + encryptionTestHelper(name) { case (name, conf) => + test(name)(fn(conf)) + } + } + + final protected def encryptionTestHelper(name: String)(fn: (String, SparkConf) => Unit): Unit = { Seq(false, true).foreach { encrypt => - test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") { - val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) - fn(conf) - } + val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) + fn(s"$name (encryption = ${ if (encrypt) "on" else "off" })", conf) } } diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala new file mode 100644 index 0000000000000..e57cb701b6284 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.security + +import java.io.Closeable +import java.net._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +class SocketAuthHelperSuite extends SparkFunSuite { + + private val conf = new SparkConf() + private val authHelper = new SocketAuthHelper(conf) + + test("successful auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + authHelper.authToServer(client) + server.close() + server.join() + assert(server.error == null) + assert(server.authenticated) + } + } + } + + test("failed auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128)) + intercept[IllegalArgumentException] { + badHelper.authToServer(client) + } + server.close() + server.join() + assert(server.error != null) + assert(!server.authenticated) + } + } + } + + private class ServerThread extends Thread with Closeable { + + private val ss = new ServerSocket() + ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)) + + @volatile var error: Exception = _ + @volatile var authenticated = false + + setDaemon(true) + start() + + def createClient(): Socket = { + new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort()) + } + + override def run(): Unit = { + var clientConn: Socket = null + try { + clientConn = ss.accept() + authHelper.authClient(clientConn) + authenticated = true + } catch { + case e: Exception => + error = e + } finally { + Option(clientConn).foreach(_.close()) + } + } + + override def close(): Unit = { + try { + ss.close() + } finally { + interrupt() + } + } + + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index fc78655bf52ec..240f8cf800fe8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -345,7 +345,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize( + HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1)) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdbd..2d8a83c6fabed 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -108,7 +108,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. 
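// Aside: a minimal, self-contained sketch of the loopback test-harness pattern that the new
// SocketAuthHelperSuite above relies on: a daemon server thread bound to an ephemeral loopback
// port, with Utils.tryWithResource-style cleanup of both sockets. The names used here
// (LoopbackHarnessSketch, withResource, EchoServer) are illustrative only and are not part of
// Spark's API; a one-byte echo stands in for the real authToServer/authClient handshake.
import java.io.Closeable
import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}

object LoopbackHarnessSketch {

  // Loan-pattern helper, analogous in spirit to Spark's Utils.tryWithResource.
  def withResource[R <: Closeable, T](create: => R)(f: R => T): T = {
    val r = create
    try f(r) finally r.close()
  }

  // Daemon thread that accepts a single connection and echoes one byte back to the client.
  private class EchoServer extends Thread with Closeable {
    private val ss = new ServerSocket()
    ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress, 0)) // port 0 = pick a free port
    setDaemon(true)
    start()

    def port: Int = ss.getLocalPort

    override def run(): Unit = {
      try {
        val conn = ss.accept()
        try {
          val b = conn.getInputStream.read()
          if (b >= 0) conn.getOutputStream.write(b)
        } finally {
          conn.close()
        }
      } catch {
        case _: Exception => // server socket closed during shutdown
      }
    }

    override def close(): Unit = ss.close()
  }

  def main(args: Array[String]): Unit = {
    withResource(new EchoServer()) { server =>
      withResource(new Socket(InetAddress.getLoopbackAddress, server.port)) { client =>
        client.getOutputStream.write(42)
        assert(client.getInputStream.read() == 42) // round-trip over the loopback socket
      }
    }
  }
}
// Binding to the loopback address with port 0, as the suite does, keeps the test local-only and
// avoids fixed-port collisions on busy CI machines.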
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala new file mode 100644 index 0000000000000..b9f0e873375b0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.lang.{Long => JLong} + +import org.mockito.Mockito.when +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.memory._ +import org.apache.spark.unsafe.Platform + +class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + + test("nested spill should be no-op") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set("spark.testing", "true") + .set("spark.testing.memory", "1600") + .set("spark.memory.fraction", "1") + sc = new SparkContext(conf) + + val memoryManager = UnifiedMemoryManager(conf, 1) + + var shouldAllocate = false + + // Mock `TaskMemoryManager` to allocate free memory when `shouldAllocate` is true. + // This will trigger a nested spill and expose issues if we don't handle this case properly. + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) { + override def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long = { + // ExecutionMemoryPool.acquireMemory will wait until there are 400 bytes for a task to use. + // So we leave 400 bytes for the task. 
+ if (shouldAllocate && + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed > 400) { + val acquireExecutionMemoryMethod = + memoryManager.getClass.getMethods.filter(_.getName == "acquireExecutionMemory").head + acquireExecutionMemoryMethod.invoke( + memoryManager, + JLong.valueOf( + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed - 400), + JLong.valueOf(1L), // taskAttemptId + MemoryMode.ON_HEAP + ).asInstanceOf[java.lang.Long] + } + super.acquireExecutionMemory(required, consumer) + } + } + val taskContext = mock[TaskContext] + val taskMetrics = new TaskMetrics + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 100, // initialSize - This will require ShuffleInMemorySorter to acquire at least 800 bytes + 1, // numPartitions + conf, + new ShuffleWriteMetrics) + val inMemSorter = { + val field = sorter.getClass.getDeclaredField("inMemSorter") + field.setAccessible(true) + field.get(sorter).asInstanceOf[ShuffleInMemorySorter] + } + // Allocate memory so that the next "insertRecord" call triggers a spill. + val bytes = new Array[Byte](1) + while (inMemSorter.hasSpaceForAnotherRecord) { + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + + // This flag will make the mocked TaskMemoryManager acquire free memory released by spill to + // trigger a nested spill. + shouldAllocate = true + + // Should throw `SparkOutOfMemoryError` as there is not enough memory: `ShuffleInMemorySorter` + // will try to acquire 800 bytes but there are only 400 bytes available. + // + // Before the fix, a nested spill may use a released page and this causes two tasks to access the + // same memory page. When a task reads memory written by another task, many types of failures + // may happen. Here are some examples we have seen: + // + // - JVM crash. (This is easy to reproduce in the unit test as we fill newly allocated and + // deallocated memory with 0xa5 and 0x5a bytes which usually points to an invalid memory + // address) + // - java.lang.IllegalArgumentException: Comparison method violates its general contract!
+ // - java.lang.NullPointerException + // at org.apache.spark.memory.TaskMemoryManager.getPage(TaskMemoryManager.java:384) + // - java.lang.UnsupportedOperationException: Cannot grow BufferHolder by size -536870912 + // because the size after growing exceeds size limitation 2147483632 + intercept[SparkOutOfMemoryError] { + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 1cd71955ad4d9..1b3639ad64a73 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -215,7 +215,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(wrapper.stageAttemptId === stages.head.attemptNumber) assert(wrapper.index === task.index) assert(wrapper.attempt === task.attemptNumber) assert(wrapper.launchTime === task.launchTime) @@ -258,7 +258,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { executorId = execIds.head, taskFailures = 2, stageId = stages.head.stageId, - stageAttemptId = stages.head.attemptId)) + stageAttemptId = stages.head.attemptNumber)) val executorStageSummaryWrappers = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") @@ -284,7 +284,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { hostId = "2.example.com", // this is where the second executor is hosted executorFailures = 1, stageId = stages.head.stageId, - stageAttemptId = stages.head.attemptId)) + stageAttemptId = stages.head.attemptNumber)) val executorStageSummaryWrappersForNode = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") @@ -468,7 +468,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { hostId = "1.example.com", executorFailures = 1, stageId = stages.last.stageId, - stageAttemptId = stages.last.attemptId)) + stageAttemptId = stages.last.attemptNumber)) check[ExecutorSummaryWrapper](execIds.head) { exec => assert(exec.info.blacklistedInStages === Set(stages.last.stageId)) @@ -963,17 +963,17 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // task end event. 
time += 1 val task = createTasks(1, Array("1")).head - listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptNumber, task)) time += 1 task.markFinished(TaskState.FINISHED, time) val metrics = TaskMetrics.empty metrics.setExecutorRunTime(42L) - listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptNumber, "taskType", Success, task, metrics)) new AppStatusStore(store) - .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) + .taskSummary(dropped.stageId, dropped.attemptNumber, Array(0.25d, 0.50d, 0.75d)) assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3) stages.drop(1).foreach { s => diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..08172f0b07b75 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1422,6 +1422,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) } + test("query locations of blockIds") { + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val blockLocations = Seq(BlockManagerId("1", "host1", 100), BlockManagerId("2", "host2", 200)) + when(mockBlockManagerMaster.getLocations(mc.any[Array[BlockId]])) + .thenReturn(Array(blockLocations)) + val env = mock(classOf[SparkEnv]) + + val blockIds: Array[BlockId] = Array(StreamBlockId(1, 2)) + val locs = BlockManager.blockIdsToLocations(blockIds, env, mockBlockManagerMaster) + val expectedLocs = Seq("executor_host1_1", "executor_host2_2") + assert(locs(blockIds(0)) == expectedLocs) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 var tempFileManager: TempFileManager = null diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index efdd02fff7871..eec961a491101 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -24,6 +24,7 @@ import com.google.common.io.{ByteStreams, Files} import io.netty.channel.FileRegion import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils @@ -94,7 +95,7 @@ class DiskStoreSuite extends SparkFunSuite { test("blocks larger than 2gb") { val conf = new SparkConf() - .set("spark.storage.memoryMapLimitForTests", "10k" ) + .set(config.MEMORY_MAP_LIMIT_FOR_TESTS.key, "10k") val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) @@ -194,8 +195,8 @@ class DiskStoreSuite extends SparkFunSuite { val region = data.toNetty().asInstanceOf[FileRegion] val byteChannel = new ByteArrayWritableChannel(data.size.toInt) - while (region.transfered() < region.count()) { - region.transferTo(byteChannel, region.transfered()) + while (region.transferred() < region.count()) 
{ + region.transferTo(byteChannel, region.transferred()) } byteChannel.close() diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala index b21c91f75d5c7..42828506895a7 100644 --- a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -22,8 +22,8 @@ import org.apache.spark._ class FlatmapIteratorSuite extends SparkFunSuite with LocalSparkContext { /* Tests the ability of Spark to deal with user provided iterators from flatMap * calls, that may generate more data then available memory. In any - * memory based persistance Spark will unroll the iterator into an ArrayBuffer - * for caching, however in the case that the use defines DISK_ONLY persistance, + * memory based persistence Spark will unroll the iterator into an ArrayBuffer + * for caching, however in the case that the use defines DISK_ONLY persistence, * the iterator will be fed directly to the serializer and written to disk. * * This also tests the ObjectOutputStream reset rate. When serializing using the diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 692ae3bf597e0..a2997dbd1b1ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) + when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer } @@ -99,7 +100,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) - ) + ).toIterator val iterator = new ShuffleBlockFetcherIterator( TaskContext.empty(), @@ -176,7 +177,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -244,7 +245,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 
intercept[FetchFailedException] { iterator.next() } } + private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.size()).thenReturn(size) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + corruptBuffer + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) @@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) sem.release() @@ -310,7 +315,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -353,11 +358,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("big blocks are not checked for corruption") { - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - doReturn(10000L).when(corruptBuffer).size() + val corruptBuffer = mockCorruptBuffer(10000L) val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -378,7 +379,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), (remoteBmId, remoteBlockLengths) - ) + ).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -413,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. 
val sem = new Semaphore(0) - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -428,16 +424,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) sem.release() } } }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -495,7 +491,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + def fetchShuffleBlock( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -513,17 +510,52 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. 
assert(tempFileManager != null) } + + test("fail zero-size blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() + ) + + val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress.toIterator, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + + // All blocks fetched return zero length and should trigger a receive-side error: + val e = intercept[FetchFailedException] { iterator.next() } + assert(e.getMessage.contains("Received a zero-size buffer")) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index a71521c91d2f2..cdc7f541b9552 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.storage +import javax.servlet.http.HttpServletRequest + import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite @@ -29,6 +31,7 @@ class StoragePageSuite extends SparkFunSuite { val storageTab = mock(classOf[StorageTab]) when(storageTab.basePath).thenReturn("http://localhost:4040") val storagePage = new StoragePage(storageTab, null) + val request = mock(classOf[HttpServletRequest]) test("rddTable") { val rdd1 = new RDDStorageInfo(1, @@ -61,7 +64,7 @@ class StoragePageSuite extends SparkFunSuite { None, None) - val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + val xmlNodes = storagePage.rddTable(request, Seq(rdd1, rdd2, rdd3)) val headers = Seq( "ID", @@ -94,7 +97,7 @@ class StoragePageSuite extends SparkFunSuite { } test("empty rddTable") { - assert(storagePage.rddTable(Seq.empty).isEmpty) + assert(storagePage.rddTable(request, Seq.empty).isEmpty) } test("streamBlockStorageLevelDescriptionAndSize") { diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index a04644d57ed88..94c79388e3639 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -162,4 +163,23 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.isZero) assert(acc3.value === "") } + + test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { + val param = new AccumulatorParam[MyData] { + override def zero(initialValue: MyData): MyData = new MyData(0) + override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) + } + + val acc = new LegacyAccumulatorWrapper(new MyData(0), param) + 
acc.metadata = AccumulatorMetadata( + AccumulatorContext.newId(), + Some("test"), + countFailedValues = false) + AccumulatorContext.register(acc) + + val ser = new JavaSerializer(new SparkConf).newInstance() + ser.serialize(acc) + } } + +class MyData(val i: Int) extends Serializable diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 9a19baee9569e..3c6660800f170 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.util import java.io.NotSerializableException +import scala.language.reflectiveCalls + import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.LocalSparkContext._ import org.apache.spark.partial.CountEvaluator @@ -121,6 +123,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass { val n2 = 222 val s2 = "bbb" @@ -141,6 +144,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass2 { val n2 = 222 val s2 = "bbb" @@ -154,6 +158,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: multiple outer classes have the same parent class") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass2 { val innerObject = new TestAbstractClass2 { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 278fada83d78c..96da8ec3b2a1c 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -145,6 +145,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get inner closure classes") { + assume(!ClosureCleanerSuite2.supportsLMFs) val closure1 = () => 1 val closure2 = () => { () => 1 } val closure3 = (i: Int) => { @@ -171,6 +172,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get outer classes and objects") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -207,6 +209,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get outer classes and objects with nesting") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val test1 = () => { @@ -258,6 +261,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -296,6 +300,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields with nesting") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val test1 = () => { @@ -538,17 +543,22 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // As before, this 
closure is neither serializable nor cleanable verifyCleaning(inner1, serializableBefore = false, serializableAfter = false) - // This closure is no longer serializable because it now has a pointer to the outer closure, - // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. - // If we do not clean transitively, we will not null out this indirect reference. - verifyCleaning( - inner2, serializableBefore = false, serializableAfter = false, transitive = false) - - // If we clean transitively, we will find that method `a` does not actually reference the - // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out - // the outer closure's parent pointer. This will make `inner2` serializable. - verifyCleaning( - inner2, serializableBefore = false, serializableAfter = true, transitive = true) + if (ClosureCleanerSuite2.supportsLMFs) { + verifyCleaning( + inner2, serializableBefore = true, serializableAfter = true) + } else { + // This closure is no longer serializable because it now has a pointer to the outer closure, + // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. + // If we do not clean transitively, we will not null out this indirect reference. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = false, transitive = false) + + // If we clean transitively, we will find that method `a` does not actually reference the + // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out + // the outer closure's parent pointer. This will make `inner2` serializable. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = true, transitive = true) + } } // Same as above, but with more levels of nesting @@ -565,4 +575,25 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri test6()()() } + test("verify nested non-LMF closures") { + assume(ClosureCleanerSuite2.supportsLMFs) + class A1(val f: Int => Int) + class A2(val f: Int => Int => Int) + class B extends A1(x => x*x) + class C extends A2(x => new B().f ) + val closure1 = new B().f + val closure2 = new C().f + // serializable already + verifyCleaning(closure1, serializableBefore = true, serializableAfter = true) + // brings in deps that can't be cleaned + verifyCleaning(closure2, serializableBefore = false, serializableAfter = false) + } +} + +object ClosureCleanerSuite2 { + // Scala 2.12 allows better interop with Java 8 via lambda syntax. This is supported + // by implementing FunctionN classes in Scala’s standard library as Single Abstract + // Method (SAM) types. Lambdas are implemented via the invokedynamic instruction and + // the use of the LambdaMetafactory (LMF) mechanism. + val supportsLMFs = scala.util.Properties.versionString.contains("2.12") } diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index ae3b3d829f1bb..604f1e1ca3101 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite { "stack trace contains unexpected references to ThreadUtils" ) } + + test("parmap should be interruptible") { + val t = new Thread() { + setDaemon(true) + + override def run() { + try { + // "par" is uninterruptible. The following will keep running even if the thread is + // interrupted.
We should prefer to use "ThreadUtils.parmap". + // + // (1 to 10).par.flatMap { i => + // Thread.sleep(100000) + // 1 to i + // } + // + ThreadUtils.parmap(1 to 10, "test", 2) { i => + Thread.sleep(100000) + 1 to i + }.flatten + } catch { + case _: InterruptedException => // expected + } + } + } + t.start() + eventually(timeout(10.seconds)) { + assert(t.isAlive) + } + t.interrupt() + eventually(timeout(10.seconds)) { + assert(!t.isAlive) + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3b4273184f1e9..418d2f9b88500 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port") } } + + object MalformedClassObject { + class MalformedClass + } + + test("Safe getSimpleName") { + // getSimpleName on class of MalformedClass will result in error: Malformed class name + // Utils.getSimpleName works + val err = intercept[java.lang.InternalError] { + classOf[MalformedClassObject.MalformedClass].getSimpleName + } + assert(err.getMessage === "Malformed class name") + + assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) === + "UtilsSuite$MalformedClassObject$MalformedClass") + } } private class SimpleExtension diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 35312f2d71131..d542ba0b6640d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -17,14 +17,24 @@ package org.apache.spark.util.collection +import java.util.Objects + import scala.collection.mutable.ArrayBuffer +import scala.ref.WeakReference + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually import org.apache.spark._ import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.util.CompletionIterator -class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { +class ExternalAppendOnlyMapSuite extends SparkFunSuite + with LocalSparkContext + with Eventually + with Matchers { import TestUtils.{assertNotSpilled, assertSpilled} private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS @@ -414,7 +424,112 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } - test("external aggregation updates peak execution memory") { + test("SPARK-22713 spill during iteration leaks internal map") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val it = map.iterator + assert(it.isInstanceOf[CompletionIterator[_, _]]) + // org.apache.spark.util.collection.AppendOnlyMap.destructiveSortedIterator returns + // an instance of an anonymous Iterator class.
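Editor's aside (not part of the patch): the SPARK-22713 tests below keep only a scala.ref.WeakReference to the map's internal buffer and then assert, after forcing GC inside eventually { ... }, that the weak reference has been cleared, which shows no strong reference is still held. A minimal stand-alone sketch of that weak-reference pattern follows; the Payload class and the retry loop are made up for illustration and are not Spark APIs.

  import scala.ref.WeakReference

  object WeakRefPatternDemo {
    // Hypothetical stand-in for the internal buffer whose reachability we want to test.
    private class Payload { val data = new Array[Byte](1 << 20) }

    def main(args: Array[String]): Unit = {
      var strong: Payload = new Payload
      val weak = WeakReference(strong)
      assert(weak.get.isDefined, "referent is still strongly reachable")

      strong = null  // drop the only strong reference
      // System.gc() is only a hint, so poll a few times instead of asserting after a
      // single call; the test suite does the same thing with eventually { ... }.
      var cleared = false
      var attempts = 0
      while (!cleared && attempts < 50) {
        System.gc()
        Thread.sleep(100)
        cleared = weak.get.isEmpty  // get returns None once the referent has been collected
        attempts += 1
      }
      assert(cleared, "weak reference was never cleared")
    }
  }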
+ + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val first50Keys = for ( _ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(map.numSpills == 0) + map.spill(Long.MaxValue, null) + // these asserts try to show that we're no longer holding references to the underlying map. + // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + // assert(map.currentMap == null) + eventually { + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + + val next50Keys = for ( _ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(!it.hasNext) + val keys = (first50Keys ++ next50Keys).sorted + assert(keys == (0 until 100)) + } + + test("drop all references to the underlying map once the iterator is exhausted") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val it = map.iterator + assert( it.isInstanceOf[CompletionIterator[_, _]]) + + + val keys = it.map{ + case (k, vs) => + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + .toList + .sorted + + assert(it.isEmpty) + assert(keys == (0 until 100)) + + assert(map.numSpills == 0) + // these asserts try to show that we're no longer holding references to the underlying map. 
+ // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + assert(map.currentMap == null) + + eventually { + Thread.sleep(500) + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + assert(it.toList.isEmpty) + } + + test("SPARK-22713 external aggregation updates peak execution memory") { val spillThreshold = 1000 val conf = createSparkConf(loadDefaults = false) .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 08a3200288981..151235dd0fb90 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -194,4 +194,50 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { val numInvalidValues = map.iterator.count(_._2 == 0) assertResult(0)(numInvalidValues) } + + test("distinguish between the 0/0.0/0L and null") { + val specializedMap1 = new OpenHashMap[String, Long] + specializedMap1("a") = null.asInstanceOf[Long] + specializedMap1("b") = 0L + assert(specializedMap1.contains("a")) + assert(!specializedMap1.contains("c")) + // null.asInstanceOf[Long] will return 0L + assert(specializedMap1("a") === 0L) + assert(specializedMap1("b") === 0L) + // If the data type is in the @specialized annotation, and + // the `key` is not contained, the `map(key)` will return 0 + assert(specializedMap1("c") === 0L) + + val specializedMap2 = new OpenHashMap[String, Double] + specializedMap2("a") = null.asInstanceOf[Double] + specializedMap2("b") = 0.toDouble + assert(specializedMap2.contains("a")) + assert(!specializedMap2.contains("c")) + // null.asInstanceOf[Double] will return 0.0 + assert(specializedMap2("a") === 0.0) + assert(specializedMap2("b") === 0.0) + assert(specializedMap2("c") === 0.0) + + val map1 = new OpenHashMap[String, Short] + map1("a") = null.asInstanceOf[Short] + map1("b") = 0.toShort + assert(map1.contains("a")) + assert(!map1.contains("c")) + // null.asInstanceOf[Short] will return 0 + assert(map1("a") === 0) + assert(map1("b") === 0) + // If the data type is not in the @specialized annotation, and + // the `key` is not contained, the `map(key)` will return null + assert(map1("c") === null) + + val map2 = new OpenHashMap[String, Float] + map2("a") = null.asInstanceOf[Float] + map2("b") = 0.toFloat + assert(map2.contains("a")) + assert(!map2.contains("c")) + // null.asInstanceOf[Float] will return 0.0 + assert(map2("a") === 0.0) + assert(map2("b") === 0.0) + assert(map2("c") === null) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 210bc5c099742..b887f937a9da9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -112,6 +112,80 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(!set.contains(10000L)) } + test("primitive float") { + val set = new OpenHashSet[Float] + assert(set.size === 0) + assert(!set.contains(10.1F)) + assert(!set.contains(50.5F)) +
assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(10.1F) + assert(set.size === 1) + assert(set.contains(10.1F)) + assert(!set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(50.5F) + assert(set.size === 2) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(999.9F) + assert(set.size === 3) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(50.5F) + assert(set.size === 3) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(set.contains(999.9F)) + assert(!set.contains(10000.1F)) + } + + test("primitive double") { + val set = new OpenHashSet[Double] + assert(set.size === 0) + assert(!set.contains(10.1D)) + assert(!set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(10.1D) + assert(set.size === 1) + assert(set.contains(10.1D)) + assert(!set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(50.5D) + assert(set.size === 2) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(999.9D) + assert(set.size === 3) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(50.5D) + assert(set.size === 3) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(set.contains(999.9D)) + assert(!set.contains(10000.1D)) + } + test("non-primitive") { val set = new OpenHashSet[String] assert(set.size === 0) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 9552d001a079c..466135e72233a 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -11,6 +11,10 @@ cache .rat-excludes .*md derby.log +licenses/* +licenses-binary/* +LICENSE +NOTICE TAGS RELEASE control @@ -106,3 +110,4 @@ spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin kafka-source-initial-offset-future-version.bin +vote.tmpl diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index e6afb18558852..8a04b621f8ce4 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -81,7 +81,7 @@ if (!(Test-Path $tools)) { # ========================== Maven Push-Location $tools -$mavenVer = "3.3.9" +$mavenVer = "3.5.4" Start-FileDownload "https://archive.apache.org/dist/maven/maven-3/$mavenVer/binaries/apache-maven-$mavenVer-bin.zip" "maven.zip" # extract diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh new file mode 100755 index 0000000000000..fa7b73cdb40ec --- /dev/null +++ b/dev/create-release/do-release-docker.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Creates a Spark release candidate. The script will update versions, tag the branch, +# build Spark binary packages and documentation, and upload maven artifacts to a staging +# repository. There is also a dry run mode where only local builds are performed, and +# nothing is uploaded to the ASF repos. +# +# Run with "-h" for options. +# + +set -e +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +function usage { + local NAME=$(basename $0) + cat < "$GPG_KEY_FILE" + +run_silent "Building spark-rm image with tag $IMGTAG..." "docker-build.log" \ + docker build -t "spark-rm:$IMGTAG" --build-arg UID=$UID "$SELF/spark-rm" + +# Write the release information to a file with environment variables to be used when running the +# image. +ENVFILE="$WORKDIR/env.list" +fcreate_secure "$ENVFILE" + +function cleanup { + rm -f "$ENVFILE" + rm -f "$GPG_KEY_FILE" +} + +trap cleanup EXIT + +cat > $ENVFILE <> $ENVFILE + JAVA_VOL="--volume $JAVA:/opt/spark-java" +fi + +echo "Building $RELEASE_TAG; output will be at $WORKDIR/output" +docker run -ti \ + --env-file "$ENVFILE" \ + --volume "$WORKDIR:/opt/spark-rm" \ + $JAVA_VOL \ + "spark-rm:$IMGTAG" diff --git a/dev/create-release/do-release.sh b/dev/create-release/do-release.sh new file mode 100755 index 0000000000000..f1d4f3ab5ddec --- /dev/null +++ b/dev/create-release/do-release.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +while getopts "bn" opt; do + case $opt in + b) GIT_BRANCH=$OPTARG ;; + n) DRY_RUN=1 ;; + ?) error "Invalid option: $OPTARG" ;; + esac +done + +if [ "$RUNNING_IN_DOCKER" = "1" ]; then + # Inside docker, need to import the GPG key stored in the current directory. + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --import "$SELF/gpg.key" + + # We may need to adjust the path since JAVA_HOME may be overridden by the driver script. + if [ -n "$JAVA_HOME" ]; then + export PATH="$JAVA_HOME/bin:$PATH" + else + # JAVA_HOME for the openjdk package. + export JAVA_HOME=/usr + fi +else + # Outside docker, need to ask for information about the release. + get_release_info +fi + +function should_build { + local WHAT=$1 + [ -z "$RELEASE_STEP" ] || [ "$WHAT" = "$RELEASE_STEP" ] +} + +if should_build "tag" && [ $SKIP_TAG = 0 ]; then + run_silent "Creating release tag $RELEASE_TAG..." 
"tag.log" \ + "$SELF/release-tag.sh" + echo "It may take some time for the tag to be synchronized to github." + echo "Press enter when you've verified that the new tag ($RELEASE_TAG) is available." + read +else + echo "Skipping tag creation for $RELEASE_TAG." +fi + +if should_build "build"; then + run_silent "Building Spark..." "build.log" \ + "$SELF/release-build.sh" package +else + echo "Skipping build step." +fi + +if should_build "docs"; then + run_silent "Building documentation..." "docs.log" \ + "$SELF/release-build.sh" docs +else + echo "Skipping docs step." +fi + +if should_build "publish"; then + run_silent "Publishing release" "publish.log" \ + "$SELF/release-build.sh" publish-release +else + echo "Skipping publish step." +fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c00b00b845401..73610a3335910 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: release-build.sh @@ -87,49 +90,56 @@ NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) -MVN="build/mvn --force" - -# Hive-specific profiles for some builds -HIVE_PROFILES="-Phive -Phive-thriftserver" -# Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" -# Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" -# Scala 2.11 only profiles for some builds -SCALA_2_11_PROFILES="-Pkafka-0-8" -# Scala 2.12 only profiles for some builds -SCALA_2_12_PROFILES="-Pscala-2.12" +init_java +init_maven_sbt rm -rf spark -git clone https://git-wip-us.apache.org/repos/asf/spark.git +git clone "$ASF_REPO" cd spark git checkout $GIT_REF git_hash=`git rev-parse --short HEAD` echo "Checked out Spark git hash $git_hash" if [ -z "$SPARK_VERSION" ]; then - SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ - | grep -v INFO | grep -v WARNING | grep -v Download) + # Run $MVN in a separate command so that 'set -e' does the right thing. + TMP=$(mktemp) + $MVN help:evaluate -Dexpression=project.version > $TMP + SPARK_VERSION=$(cat $TMP | grep -v INFO | grep -v WARNING | grep -v Download) + rm $TMP fi -# Verify we have the right java version set -if [ -z "$JAVA_HOME" ]; then - echo "Please set JAVA_HOME." - exit 1 +# Depending on the version being built, certain extra profiles need to be activated, and +# different versions of Scala are supported. +BASE_PROFILES="-Pmesos -Pyarn" +PUBLISH_SCALA_2_10=0 +SCALA_2_10_PROFILES="-Pscala-2.10" +SCALA_2_11_PROFILES= +SCALA_2_12_PROFILES="-Pscala-2.12" + +if [[ $SPARK_VERSION > "2.3" ]]; then + BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" + SCALA_2_11_PROFILES="-Pkafka-0-8" +else + PUBLISH_SCALA_2_10=1 fi -java_version=$("${JAVA_HOME}"/bin/javac -version 2>&1 | cut -d " " -f 2) +# Hive-specific profiles for some builds +HIVE_PROFILES="-Phive -Phive-thriftserver" +# Profiles for publishing snapshots and release to Maven Central +PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +# Profiles for building binary releases +BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr" if [[ ! $SPARK_VERSION < "2.2." ]]; then - if [[ $java_version < "1.8." 
]]; then - echo "Java version $java_version is less than required 1.8 for 2.2+" + if [[ $JAVA_VERSION < "1.8." ]]; then + echo "Java version $JAVA_VERSION is less than required 1.8 for 2.2+" echo "Please set JAVA_HOME correctly." exit 1 fi else - if [[ $java_version > "1.7." ]]; then + if ! [[ $JAVA_VERSION =~ 1\.7\..* ]]; then if [ -z "$JAVA_7_HOME" ]; then - echo "Java version $java_version is higher than required 1.7 for pre-2.2" + echo "Java version $JAVA_VERSION is higher than required 1.7 for pre-2.2" echo "Please set JAVA_HOME correctly." exit 1 else @@ -168,14 +178,20 @@ if [[ "$1" == "package" ]]; then SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION + ZINC_PORT=3035 + # Updated for each binary build make_binary_release() { NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - BUILD_PACKAGE=$4 - cp -r spark spark-$SPARK_VERSION-bin-$NAME + FLAGS="$MVN_EXTRA_OPTS -B $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES $2" + BUILD_PACKAGE=$3 + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. + ZINC_PORT=$((ZINC_PORT + 1)) + echo "Building binary dist $NAME" + cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME # TODO There should probably be a flag to make-distribution to allow 2.12 support @@ -244,31 +260,58 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } - # TODO: Check exit codes of children here: - # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process + # List of binary packages built. Populates two associative arrays, where the key is the "name" of + # the package being built, and the values are respectively the needed maven arguments for building + # the package, and any extra package needed for that particular combination. + # + # In dry run mode, only build the first one. The keys in BINARY_PKGS_ARGS are used as the + # list of packages to be built, so it's ok for things to be missing in BINARY_PKGS_EXTRA. + + declare -A BINARY_PKGS_ARGS + BINARY_PKGS_ARGS["hadoop2.7"]="-Phadoop-2.7 $HIVE_PROFILES" + if ! is_dry_run; then + BINARY_PKGS_ARGS["hadoop2.6"]="-Phadoop-2.6 $HIVE_PROFILES" + BINARY_PKGS_ARGS["without-hadoop"]="-Pwithout-hadoop" + if [[ $SPARK_VERSION < "2.2." ]]; then + BINARY_PKGS_ARGS["hadoop2.4"]="-Phadoop-2.4 $HIVE_PROFILES" + BINARY_PKGS_ARGS["hadoop2.3"]="-Phadoop-2.3 $HIVE_PROFILES" + fi + fi - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - make_binary_release "hadoop2.6" "-Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr" & - make_binary_release "hadoop2.7" "-Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip" & - make_binary_release "without-hadoop" "-Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3038" & - wait - rm -rf spark-$SPARK_VERSION-bin-*/ + declare -A BINARY_PKGS_EXTRA + BINARY_PKGS_EXTRA["hadoop2.7"]="withpip" + if ! is_dry_run; then + BINARY_PKGS_EXTRA["hadoop2.6"]="withr" + fi + + echo "Packages to build: ${!BINARY_PKGS_ARGS[@]}" + for key in ${!BINARY_PKGS_ARGS[@]}; do + args=${BINARY_PKGS_ARGS[$key]} + extra=${BINARY_PKGS_EXTRA[$key]} + if ! make_binary_release "$key" "$args" "$extra"; then + error "Failed to build $key package. Check logs for details." 
+ fi + done - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-bin" - mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + rm -rf spark-$SPARK_VERSION-bin-*/ - echo "Copying release tarballs" - cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" - svn add "svn-spark/${DEST_DIR_NAME}-bin" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-bin" + mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + + echo "Copying release tarballs" + cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" + svn add "svn-spark/${DEST_DIR_NAME}-bin" + + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" + cd .. + rm -rf svn-spark + fi - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" - cd .. - rm -rf svn-spark exit 0 fi @@ -282,18 +325,22 @@ if [[ "$1" == "docs" ]]; then cd .. cd .. - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-docs" - mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-docs" + mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" - echo "Copying release documentation" - cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" - svn add "svn-spark/${DEST_DIR_NAME}-docs" + echo "Copying release documentation" + cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" + svn add "svn-spark/${DEST_DIR_NAME}-docs" - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" - cd .. - rm -rf svn-spark + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" + cd .. + rm -rf svn-spark + fi + + mv "spark/docs/_site" docs/ exit 0 fi @@ -341,13 +388,15 @@ if [[ "$1" == "publish-release" ]]; then # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" + if ! is_dry_run; then + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + fi tmp_repo=$(mktemp -d spark-repo-XXXXX) @@ -356,6 +405,12 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install + if ! 
is_dry_run && [[ $PUBLISH_SCALA_2_10 = 1 ]]; then + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$((ZINC_PORT + 1)) -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ + -DskipTests $PUBLISH_PROFILES $SCALA_2_10_PROFILES clean install + fi + #./dev/change-scala-version.sh 2.12 #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo \ # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install @@ -371,31 +426,41 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc and .sha1 - it really doesn't like anything else there + # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done + if ! is_dry_run; then + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . -type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + fi - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" popd rm -rf $tmp_repo cd .. diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index a05716a5f66bb..628bc0504c9c8 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. 
"$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: tag-release.sh @@ -36,6 +39,7 @@ EOF } set -e +set -o pipefail if [[ $@ == *"help"* ]]; then exit_with_usage @@ -54,8 +58,10 @@ for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GI fi done +init_java +init_maven_sbt + ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" -MVN="build/mvn --force" rm -rf spark git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b $GIT_BRANCH @@ -94,9 +100,15 @@ sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION" git commit -a -m "Preparing development version $NEXT_VERSION" -# Push changes -git push origin $RELEASE_TAG -git push origin HEAD:$GIT_BRANCH - -cd .. -rm -rf spark +if ! is_dry_run; then + # Push changes + git push origin $RELEASE_TAG + git push origin HEAD:$GIT_BRANCH + + cd .. + rm -rf spark +else + cd .. + mv spark spark.tag + echo "Clone with version changes and tag available as spark.tag in the output directory." +fi diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh new file mode 100644 index 0000000000000..7426b0d6ca08d --- /dev/null +++ b/dev/create-release/release-util.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +DRY_RUN=${DRY_RUN:-0} +GPG="gpg --no-tty --batch" +ASF_REPO="https://git-wip-us.apache.org/repos/asf/spark.git" +ASF_REPO_WEBUI="https://git-wip-us.apache.org/repos/asf?p=spark.git" + +function error { + echo "$*" + exit 1 +} + +function read_config { + local PROMPT="$1" + local DEFAULT="$2" + local REPLY= + + read -p "$PROMPT [$DEFAULT]: " REPLY + local RETVAL="${REPLY:-$DEFAULT}" + if [ -z "$RETVAL" ]; then + error "$PROMPT is must be provided." + fi + echo "$RETVAL" +} + +function parse_version { + grep -e '.*' | \ + head -n 2 | tail -n 1 | cut -d'>' -f2 | cut -d '<' -f1 +} + +function run_silent { + local BANNER="$1" + local LOG_FILE="$2" + shift 2 + + echo "========================" + echo "= $BANNER" + echo "Command: $@" + echo "Log file: $LOG_FILE" + + "$@" 1>"$LOG_FILE" 2>&1 + + local EC=$? + if [ $EC != 0 ]; then + echo "Command FAILED. Check full logs for details." + tail "$LOG_FILE" + exit $EC + fi +} + +function fcreate_secure { + local FPATH="$1" + rm -f "$FPATH" + touch "$FPATH" + chmod 600 "$FPATH" +} + +function check_for_tag { + curl -s --head --fail "$ASF_REPO_WEBUI;a=commit;h=$1" >/dev/null +} + +function get_release_info { + if [ -z "$GIT_BRANCH" ]; then + # If no branch is specified, found out the latest branch from the repo. 
+ GIT_BRANCH=$(git ls-remote --heads "$ASF_REPO" | + grep -v refs/heads/master | + awk '{print $2}' | + sort -r | + head -n 1 | + cut -d/ -f3) + fi + + export GIT_BRANCH=$(read_config "Branch" "$GIT_BRANCH") + + # Find the current version for the branch. + local VERSION=$(curl -s "$ASF_REPO_WEBUI;a=blob_plain;f=pom.xml;hb=refs/heads/$GIT_BRANCH" | + parse_version) + echo "Current branch version is $VERSION." + + if [[ ! $VERSION =~ .*-SNAPSHOT ]]; then + error "Not a SNAPSHOT version: $VERSION" + fi + + NEXT_VERSION="$VERSION" + RELEASE_VERSION="${VERSION/-SNAPSHOT/}" + SHORT_VERSION=$(echo "$VERSION" | cut -d . -f 1-2) + local REV=$(echo "$VERSION" | cut -d . -f 3) + + # Find out what rc is being prepared. + # - If the current version is "x.y.0", then this is rc1 of the "x.y.0" release. + # - If not, need to check whether the previous version has been already released or not. + # - If it has, then we're building rc1 of the current version. + # - If it has not, we're building the next RC of the previous version. + local RC_COUNT + if [ $REV != 0 ]; then + local PREV_REL_REV=$((REV - 1)) + local PREV_REL_TAG="v${SHORT_VERSION}.${PREV_REL_REV}" + if check_for_tag "$PREV_REL_TAG"; then + RC_COUNT=1 + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + else + RELEASE_VERSION="${SHORT_VERSION}.${PREV_REL_REV}" + RC_COUNT=$(git ls-remote --tags "$ASF_REPO" "v${RELEASE_VERSION}-rc*" | wc -l) + RC_COUNT=$((RC_COUNT + 1)) + fi + else + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + RC_COUNT=1 + fi + + export NEXT_VERSION + export RELEASE_VERSION=$(read_config "Release" "$RELEASE_VERSION") + + RC_COUNT=$(read_config "RC #" "$RC_COUNT") + + # Check if the RC already exists, and if re-creating the RC, skip tag creation. + RELEASE_TAG="v${RELEASE_VERSION}-rc${RC_COUNT}" + SKIP_TAG=0 + if check_for_tag "$RELEASE_TAG"; then + read -p "$RELEASE_TAG already exists. Continue anyway [y/n]? " ANSWER + if [ "$ANSWER" != "y" ]; then + error "Exiting." + fi + SKIP_TAG=1 + fi + + + export RELEASE_TAG + + GIT_REF="$RELEASE_TAG" + if is_dry_run; then + echo "This is a dry run. Please confirm the ref that will be built for testing." + GIT_REF=$(read_config "Ref" "$GIT_REF") + fi + export GIT_REF + export SPARK_PACKAGE_VERSION="$RELEASE_TAG" + + # Gather some user information. + export ASF_USERNAME=$(read_config "ASF user" "$LOGNAME") + + GIT_NAME=$(git config user.name || echo "") + export GIT_NAME=$(read_config "Full name" "$GIT_NAME") + + export GIT_EMAIL="$ASF_USERNAME@apache.org" + export GPG_KEY=$(read_config "GPG key" "$GIT_EMAIL") + + cat <&1 | cut -d " " -f 2) + export JAVA_VERSION +} + +# Initializes MVN_EXTRA_OPTS and SBT_OPTS depending on the JAVA_VERSION in use. Requires init_java. +function init_maven_sbt { + MVN="build/mvn -B" + MVN_EXTRA_OPTS= + SBT_OPTS= + if [[ $JAVA_VERSION < "1.8." ]]; then + # Needed for maven central when using Java 7. 
+ SBT_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN_EXTRA_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN="$MVN $MVN_EXTRA_OPTS" + fi + export MVN MVN_EXTRA_OPTS SBT_OPTS +} diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 32f6cbb29f0be..8cc990d871842 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -49,13 +49,16 @@ print("Install using 'sudo pip install unidecode'") sys.exit(-1) +if sys.version < '3': + input = raw_input # noqa + # Contributors list file name contributors_file_name = "contributors.txt" # Prompt the user to answer yes or no until they do so def yesOrNoPrompt(msg): - response = raw_input("%s [y/n]: " % msg) + response = input("%s [y/n]: " % msg) while response != "y" and response != "n": return yesOrNoPrompt(msg) return response == "y" @@ -149,7 +152,11 @@ def get_commits(tag): if not is_valid_author(author): author = github_username # Guard against special characters - author = unidecode.unidecode(unicode(author, "UTF-8")).strip() + try: # Python 2 + author = unicode(author, "UTF-8") + except NameError: # Python 3 + author = str(author) + author = unidecode.unidecode(author).strip() commit = Commit(_hash, author, title, pr_number) commits.append(commit) return commits diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile new file mode 100644 index 0000000000000..07ce320177f5a --- /dev/null +++ b/dev/create-release/spark-rm/Dockerfile @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building Spark releases. Based on Ubuntu 16.04. +# +# Includes: +# * Java 8 +# * Ivy +# * Python/PyPandoc (2.7.12/3.5.2) +# * R-base/R-base-dev (3.3.2+) +# * Ruby 2.3 build utilities + +FROM ubuntu:16.04 + +# These arguments are just for reuse and not really meant to be customized. +ARG APT_INSTALL="apt-get install --no-install-recommends -y" + +ARG BASE_PIP_PKGS="setuptools wheel virtualenv" +ARG PIP_PKGS="pyopenssl pypandoc numpy pygments sphinx" + +# Install extra needed repos and refresh. +# - CRAN repo +# - Ruby repo (for doc generation) +# +# This is all in a single "RUN" command so that if anything changes, "apt update" is run to fetch +# the most current package versions (instead of potentially using old versions cached by docker). +RUN echo 'deb http://cran.cnr.Berkeley.edu/bin/linux/ubuntu xenial/' >> /etc/apt/sources.list && \ + gpg --keyserver keyserver.ubuntu.com --recv-key E084DAB9 && \ + gpg -a --export E084DAB9 | apt-key add - && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean && \ + apt-get update && \ + $APT_INSTALL software-properties-common && \ + apt-add-repository -y ppa:brightbox/ruby-ng && \ + apt-get update && \ + # Install openjdk 8. 
+ $APT_INSTALL openjdk-8-jdk && \ + update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java && \ + # Install build / source control tools + $APT_INSTALL curl wget git maven ivy subversion make gcc lsof libffi-dev \ + pandoc pandoc-citeproc libssl-dev libcurl4-openssl-dev libxml2-dev && \ + ln -s -T /usr/share/java/ivy.jar /usr/share/ant/lib/ivy.jar && \ + curl -sL https://deb.nodesource.com/setup_4.x | bash && \ + $APT_INSTALL nodejs && \ + # Install needed python packages. Use pip for installing packages (for consistency). + $APT_INSTALL libpython2.7-dev libpython3-dev python-pip python3-pip && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + cd && \ + virtualenv -p python3 p35 && \ + . p35/bin/activate && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + # Install R packages and dependencies used when building. + # R depends on pandoc*, libssl (which are installed above). + $APT_INSTALL r-base r-base-dev && \ + $APT_INSTALL texlive-latex-base texlive texlive-fonts-extra texinfo qpdf && \ + Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='http://cran.us.r-project.org/')" && \ + Rscript -e "devtools::install_github('jimhester/lintr')" && \ + # Install tools needed to build the documentation. + $APT_INSTALL ruby2.3 ruby2.3-dev && \ + gem install jekyll --no-rdoc --no-ri && \ + gem install jekyll-redirect-from && \ + gem install pygments.rb + +WORKDIR /opt/spark-rm/output + +ARG UID +RUN useradd -m -s /bin/bash -p spark-rm -u $UID spark-rm +USER spark-rm:spark-rm + +ENTRYPOINT [ "/opt/spark-rm/do-release.sh" ] diff --git a/dev/create-release/vote.tmpl b/dev/create-release/vote.tmpl new file mode 100644 index 0000000000000..2ce953c2f7ec4 --- /dev/null +++ b/dev/create-release/vote.tmpl @@ -0,0 +1,65 @@ +Please vote on releasing the following candidate as Apache Spark version {version}. + +The vote is open until {deadline} and passes if a majority +1 PMC votes are cast, with +a minimum of 3 +1 votes. + +[ ] +1 Release this package as Apache Spark {version} +[ ] -1 Do not release this package because ... + +To learn more about Apache Spark, please see http://spark.apache.org/ + +The tag to be voted on is {tag} (commit {tag_commit}): +https://github.com/apache/spark/tree/{tag} + +The release files, including signatures, digests, etc. can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-bin/ + +Signatures used for Spark RCs can be found in this file: +https://dist.apache.org/repos/dist/dev/spark/KEYS + +The staging repository for this release can be found at: +https://repository.apache.org/content/repositories/orgapachespark-{repo_id}/ + +The documentation corresponding to this release can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-docs/ + +The list of bug fixes going into {version} can be found at the following URL: +https://issues.apache.org/jira/projects/SPARK/versions/{jira_version_id} + +FAQ + +========================= +How can I help test this release? +========================= + +If you are a Spark user, you can help us test this release by taking +an existing Spark workload and running on this release candidate, then +reporting any regressions. 
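Editor's aside (illustrative only, not part of vote.tmpl): one way to exercise a release candidate from Java/Scala is to point a throwaway sbt build at the staging repository linked above. The repo id (1234) and the Spark version string below are placeholders, not values from any particular release.

  // build.sbt for a scratch project that smoke-tests the RC
  scalaVersion := "2.11.12"

  // Staging repository URL follows the orgapachespark-{repo_id} pattern shown above.
  resolvers += "Apache Spark staging" at
    "https://repository.apache.org/content/repositories/orgapachespark-1234/"

  // Placeholder version string; substitute the version under vote.
  libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0" % "provided"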
+ +If you're working in PySpark you can set up a virtual env and install +the current RC and see if anything important breaks, in the Java/Scala +you can add the staging repository to your projects resolvers and test +with the RC (make sure to clean up the artifact cache before/after so +you don't end up building with a out of date RC going forward). + +=========================================== +What should happen to JIRA tickets still targeting {version}? +=========================================== + +The current list of open tickets targeted at {version} can be found at: +{open_issues_link} + +Committers should look at those and triage. Extremely important bug +fixes, documentation, and API tweaks that impact compatibility should +be worked on immediately. Everything else please retarget to an +appropriate release. + +================== +But my bug isn't fixed? +================== + +In order to make timely releases, we will typically not hold the +release unless the bug in question is a regression from the previous +release. That being said, if there is something which is a regression +that has not been correctly targeted please ping me or a committer to +help target the issue. \ No newline at end of file diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index c3d1dd444b506..fc42af905c2fe 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -14,15 +14,13 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.8.0.jar -arrow-memory-0.8.0.jar -arrow-vector-0.8.0.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar automaton-1.11-8.jar -avro-1.7.7.jar -avro-ipc-1.7.7.jar -avro-mapred-1.7.7-hadoop2.jar -base64-2.3.8.jar -bcprov-jdk15on-1.58.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -36,8 +34,8 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.8.jar -commons-compress-1.4.1.jar +commons-compiler-3.0.9.jar +commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar commons-dbcp-1.4.jar @@ -86,8 +84,8 @@ hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar hppc-0.7.2.jar htrace-core-3.0.4.jar -httpclient-4.5.4.jar -httpcore-4.4.8.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar ivy-2.4.0.jar jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar @@ -100,8 +98,7 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.8.jar -java-xmlbuilder-1.1.jar +janino-3.0.9.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -119,10 +116,9 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -157,25 +153,26 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar 
+orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar @@ -190,12 +187,12 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar -xz-1.0.jar +xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 290867035f91d..54e50556b4620 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -14,15 +14,13 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.8.0.jar -arrow-memory-0.8.0.jar -arrow-vector-0.8.0.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar automaton-1.11-8.jar -avro-1.7.7.jar -avro-ipc-1.7.7.jar -avro-mapred-1.7.7-hadoop2.jar -base64-2.3.8.jar -bcprov-jdk15on-1.58.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -36,8 +34,8 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.8.jar -commons-compress-1.4.1.jar +commons-compiler-3.0.9.jar +commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar commons-dbcp-1.4.jar @@ -66,28 +64,28 @@ gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.3.jar -hadoop-auth-2.7.3.jar -hadoop-client-2.7.3.jar -hadoop-common-2.7.3.jar -hadoop-hdfs-2.7.3.jar -hadoop-mapreduce-client-app-2.7.3.jar -hadoop-mapreduce-client-common-2.7.3.jar -hadoop-mapreduce-client-core-2.7.3.jar -hadoop-mapreduce-client-jobclient-2.7.3.jar -hadoop-mapreduce-client-shuffle-2.7.3.jar -hadoop-yarn-api-2.7.3.jar -hadoop-yarn-client-2.7.3.jar -hadoop-yarn-common-2.7.3.jar -hadoop-yarn-server-common-2.7.3.jar -hadoop-yarn-server-web-proxy-2.7.3.jar +hadoop-annotations-2.7.7.jar +hadoop-auth-2.7.7.jar +hadoop-client-2.7.7.jar +hadoop-common-2.7.7.jar +hadoop-hdfs-2.7.7.jar +hadoop-mapreduce-client-app-2.7.7.jar +hadoop-mapreduce-client-common-2.7.7.jar +hadoop-mapreduce-client-core-2.7.7.jar +hadoop-mapreduce-client-jobclient-2.7.7.jar +hadoop-mapreduce-client-shuffle-2.7.7.jar +hadoop-yarn-api-2.7.7.jar +hadoop-yarn-client-2.7.7.jar +hadoop-yarn-common-2.7.7.jar +hadoop-yarn-server-common-2.7.7.jar +hadoop-yarn-server-web-proxy-2.7.7.jar 
hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar hppc-0.7.2.jar htrace-core-3.1.0-incubating.jar -httpclient-4.5.4.jar -httpcore-4.4.8.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar ivy-2.4.0.jar jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar @@ -100,8 +98,7 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.8.jar -java-xmlbuilder-1.1.jar +janino-3.0.9.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -119,10 +116,10 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.4.jar jetty-6.1.26.jar +jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -158,25 +155,26 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar @@ -191,12 +189,12 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar -xz-1.0.jar +xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 new file mode 100644 index 0000000000000..ff5713b5b66b7 --- /dev/null +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -0,0 +1,218 @@ +HikariCP-java7-2.4.12.jar +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +accessors-smart-1.2.jar +activation-1.1.1.jar +aircompressor-0.10.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.7.jar +aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar +automaton-1.11-8.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.13.2.jar +breeze_2.11-0.13.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.4.jar +chill_2.11-0.8.4.jar +commons-beanutils-1.9.3.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-3.0.9.jar +commons-compress-1.8.1.jar +commons-configuration2-2.1.1.jar +commons-crypto-1.0.0.jar +commons-daemon-1.0.13.jar +commons-dbcp-1.4.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar 
+commons-lang-2.6.jar +commons-lang3-3.5.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-3.1.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.12.0.jar +curator-framework-2.12.0.jar +curator-recipes-2.12.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.12.1.1.jar +dnsjava-2.1.7.jar +ehcache-3.3.1.jar +eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar +geronimo-jcache_1.0_spec-1.0-alpha-1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-4.0.jar +guice-servlet-4.0.jar +hadoop-annotations-3.1.0.jar +hadoop-auth-3.1.0.jar +hadoop-client-3.1.0.jar +hadoop-common-3.1.0.jar +hadoop-hdfs-client-3.1.0.jar +hadoop-mapreduce-client-common-3.1.0.jar +hadoop-mapreduce-client-core-3.1.0.jar +hadoop-mapreduce-client-jobclient-3.1.0.jar +hadoop-yarn-api-3.1.0.jar +hadoop-yarn-client-3.1.0.jar +hadoop-yarn-common-3.1.0.jar +hadoop-yarn-registry-3.1.0.jar +hadoop-yarn-server-common-3.1.0.jar +hadoop-yarn-server-web-proxy-3.1.0.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +hppc-0.7.2.jar +htrace-core4-4.1.0-incubating.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar +ivy-2.4.0.jar +jackson-annotations-2.6.7.jar +jackson-core-2.6.7.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar +jackson-jaxrs-base-2.7.8.jar +jackson-jaxrs-json-provider-2.7.8.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar +jackson-module-paranamer-2.7.9.jar +jackson-module-scala_2.11-2.6.7.1.jar +janino-3.0.9.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar +javax.inject-1.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.11.jar +jcip-annotations-1.0-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar +jetty-webapp-9.3.24.v20180605.jar +jetty-xml-9.3.24.v20180605.jar +jline-2.14.6.jar +joda-time-2.9.3.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-smart-2.3.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kerb-admin-1.0.1.jar +kerb-client-1.0.1.jar +kerb-common-1.0.1.jar +kerb-core-1.0.1.jar +kerb-crypto-1.0.1.jar +kerb-identity-1.0.1.jar +kerb-server-1.0.1.jar +kerb-simplekdc-1.0.1.jar +kerb-util-1.0.1.jar +kerby-asn1-1.0.1.jar +kerby-config-1.0.1.jar +kerby-pkix-1.0.1.jar +kerby-util-1.0.1.jar +kerby-xdr-1.0.1.jar +kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar +leveldbjni-all-1.8.jar +libfb303-0.9.3.jar +libthrift-0.9.3.jar +log4j-1.2.17.jar +logging-interceptor-3.8.1.jar +lz4-java-1.4.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar +mesos-1.4.0-shaded-protobuf.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar +minlog-1.3.0.jar +mssql-jdbc-6.2.1.jre7.jar +netty-3.9.9.Final.jar +netty-all-4.1.17.Final.jar +nimbus-jose-jwt-4.41.1.jar +objenesis-2.1.jar +okhttp-2.7.5.jar +okhttp-3.8.1.jar +okio-1.13.0.jar +opencsv-2.3.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar +oro-2.0.8.jar 
+osgi-resource-locator-1.0.1.jar +paranamer-2.8.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.10.0.jar +protobuf-java-2.5.0.jar +py4j-0.10.7.jar +pyrolite-4.13.jar +re2j-1.1.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar +scala-xml_2.11-1.0.5.jar +shapeless_2.11-2.3.2.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar +snappy-0.2.jar +snappy-java-1.1.7.1.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar +stax-api-1.0.1.jar +stax2-api-3.1.4.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +token-provider-1.0.1.jar +univocity-parsers-2.7.3.jar +validation-api-1.1.0.Final.jar +woodstox-core-5.0.3.jar +xbean-asm6-shaded-4.8.jar +xz-1.5.jar +zjsonpatch-0.3.0.jar +zookeeper-3.4.9.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 84233c64caa9c..ad99ce55806af 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -211,9 +211,10 @@ mkdir -p "$DISTDIR/examples/src/main" cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files -cp "$SPARK_HOME/LICENSE" "$DISTDIR" -cp -r "$SPARK_HOME/licenses" "$DISTDIR" -cp "$SPARK_HOME/NOTICE" "$DISTDIR" +cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" +mkdir -p "$DISTDIR/licenses" +cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" +cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 5ea205fbed4aa..28a6714856c10 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -39,6 +39,9 @@ except ImportError: JIRA_IMPORTED = False +if sys.version < '3': + input = raw_input # noqa + # Location of your Spark git development area SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site @@ -95,20 +98,21 @@ def run_cmd(cmd): def continue_maybe(prompt): - result = raw_input("\n%s (y/n): " % prompt) + result = input("\n%s (y/n): " % prompt) if result.lower() != "y": fail("Okay, exiting") def clean_up(): - print("Restoring head pointer to %s" % original_head) - run_cmd("git checkout %s" % original_head) + if 'original_head' in globals(): + print("Restoring head pointer to %s" % original_head) + run_cmd("git checkout %s" % original_head) - branches = run_cmd("git branch").replace(" ", "").split("\n") + branches = run_cmd("git branch").replace(" ", "").split("\n") - for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): - print("Deleting local branch %s" % branch) - run_cmd("git branch -D %s" % branch) + for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): + print("Deleting local branch %s" % branch) + run_cmd("git branch -D %s" % branch) # merge the requested PR and return the merge hash @@ -133,11 +137,16 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = raw_input( + primary_author = input( "Enter primary author in the format of \"name \" [%s]: " % distinct_authors[0]) if primary_author == "": primary_author = distinct_authors[0] + else: + # When primary author is specified manually, de-dup it from author 
list and + # put it at the head of author list. + distinct_authors = list(filter(lambda x: x != primary_author, distinct_authors)) + distinct_authors.insert(0, primary_author) commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") @@ -150,13 +159,10 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): # to people every time someone creates a public fork of Spark. merge_message_flags += ["-m", body.replace("@", "")] - authors = "\n".join(["Author: %s" % a for a in distinct_authors]) - - merge_message_flags += ["-m", authors] + committer_name = run_cmd("git config --get user.name").strip() + committer_email = run_cmd("git config --get user.email").strip() if had_conflicts: - committer_name = run_cmd("git config --get user.name").strip() - committer_email = run_cmd("git config --get user.email").strip() message = "This patch had conflicts when merged, resolved by\nCommitter: %s <%s>" % ( committer_name, committer_email) merge_message_flags += ["-m", message] @@ -164,6 +170,14 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): # The string "Closes #%s" string is required for GitHub to correctly close the PR merge_message_flags += ["-m", "Closes #%s from %s." % (pr_num, pr_repo_desc)] + authors = "Authored-by:" if len(distinct_authors) == 1 else "Lead-authored-by:" + authors += " %s" % (distinct_authors.pop(0)) + if len(distinct_authors) > 0: + authors += "\n" + "\n".join(["Co-authored-by: %s" % a for a in distinct_authors]) + authors += "\n" + "Signed-off-by: %s <%s>" % (committer_name, committer_email) + + merge_message_flags += ["-m", authors] + run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) continue_maybe("Merge complete (local ref %s). Push to %s?" 
% ( @@ -183,7 +197,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): def cherry_pick(pr_num, merge_hash, default_branch): - pick_ref = raw_input("Enter a branch name [%s]: " % default_branch) + pick_ref = input("Enter a branch name [%s]: " % default_branch) if pick_ref == "": pick_ref = default_branch @@ -230,7 +244,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): asf_jira = jira.client.JIRA({'server': JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) - jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) + jira_id = input("Enter a JIRA id [%s]: " % default_jira_id) if jira_id == "": jira_id = default_jira_id @@ -275,7 +289,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): default_fix_versions = filter(lambda x: x != v, default_fix_versions) default_fix_versions = ",".join(default_fix_versions) - fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) + fix_versions = input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) if fix_versions == "": fix_versions = default_fix_versions fix_versions = fix_versions.replace(" ", "").split(",") @@ -314,8 +328,8 @@ def choose_jira_assignee(issue, asf_jira): if author in commentors: annotations.append("Commentor") print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - raw_assignee = raw_input( - "Enter number of user, or userid, to assign to (blank to leave unassigned):") + raw_assignee = input( + "Enter number of user, or userid, to assign to (blank to leave unassigned):") if raw_assignee == "": return None else: @@ -327,6 +341,8 @@ def choose_jira_assignee(issue, asf_jira): assignee = asf_jira.user(raw_assignee) asf_jira.assign_issue(issue.key, assignee.key) return assignee + except KeyboardInterrupt: + raise except: traceback.print_exc() print("Error assigning JIRA, try again (or leave blank and fix manually)") @@ -358,8 +374,8 @@ def standardize_jira_ref(text): >>> standardize_jira_ref("[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl") '[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl' >>> standardize_jira_ref( - ... "SPARK-1094 Support MiMa for reporting binary compatibility accross versions.") - '[SPARK-1094] Support MiMa for reporting binary compatibility accross versions.' + ... "SPARK-1094 Support MiMa for reporting binary compatibility across versions.") + '[SPARK-1094] Support MiMa for reporting binary compatibility across versions.' >>> standardize_jira_ref("[WIP] [SPARK-1146] Vagrant support for Spark") '[SPARK-1146][WIP] Vagrant support for Spark' >>> standardize_jira_ref( @@ -427,7 +443,7 @@ def main(): # Assumes branch names can be sorted lexicographically latest_branch = sorted(branch_names, reverse=True)[0] - pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") + pr_num = input("Which pull request would you like to merge? (e.g. 34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) @@ -439,7 +455,7 @@ def main(): print("I've re-written the title as follows to match the standard format:") print("Original: %s" % pr["title"]) print("Modified: %s" % modified_title) - result = raw_input("Would you like to use the modified title? (y/n): ") + result = input("Would you like to use the modified title? 
(y/n): ") if result.lower() == "y": title = modified_title print("Using modified title:") @@ -490,7 +506,7 @@ def main(): merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) pick_prompt = "Would you like to pick %s into another branch?" % merge_hash - while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + while input("\n%s (y/n): " % pick_prompt).lower() == "y": merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: diff --git a/dev/requirements.txt b/dev/requirements.txt index 79782279f8fbd..fa833ab96b8e7 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -2,3 +2,4 @@ jira==1.0.3 PyGithub==1.26.0 Unidecode==0.04.19 pypandoc==1.3.3 +sphinx diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 1321c2be4c192..60cf4d8209416 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -52,7 +52,7 @@ if hash virtualenv 2>/dev/null && [ ! -n "$USE_CONDA" ]; then PYTHON_EXECS+=('python3') fi elif hash conda 2>/dev/null; then - echo "Using conda virtual enviroments" + echo "Using conda virtual environments" PYTHON_EXECS=('3.5') USE_CONDA=1 else @@ -88,8 +88,8 @@ for python in "${PYTHON_EXECS[@]}"; do virtualenv --python=$python "$VIRTUALENV_PATH" source "$VIRTUALENV_PATH"/bin/activate fi - # Upgrade pip & friends if using virutal env - if [ ! -n "USE_CONDA" ]; then + # Upgrade pip & friends if using virtual env + if [ ! -n "$USE_CONDA" ]; then pip install --upgrade pip pypandoc wheel numpy fi @@ -123,7 +123,7 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR" - # conda / virtualenv enviroments need to be deactivated differently + # conda / virtualenv environments need to be deactivated differently if [ -n "$USE_CONDA" ]; then source deactivate else diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 3960a0de62530..e6fe3b82ed202 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -181,8 +181,9 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 350m) - tests_timeout = "300m" + # must be less than the timeout configured on Jenkins. Usually Jenkins's timeout is higher + # then this. Please consult with the build manager or a committer when it should be increased. + tests_timeout = "400m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. 
diff --git a/dev/run-tests.py b/dev/run-tests.py index 164c1e2200aa9..d9d3789ac1255 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', + ['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() @@ -204,7 +204,7 @@ def run_scala_style_checks(): def run_java_style_checks(): set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") - run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + run_cmd([os.path.join(SPARK_HOME, "dev", "sbt-checkstyle")]) def run_python_style_checks(): @@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version): exec_sbt(profiles_and_goals) -def build_spark_assembly_sbt(hadoop_version): +def build_spark_assembly_sbt(hadoop_version, checkstyle=False): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["assembly/package"] @@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version): " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + if checkstyle: + run_java_style_checks() + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the # documentation build fails on a specific machine & environment in Jenkins but it was unable @@ -570,12 +573,13 @@ def main(): or f.endswith("scalastyle-config.xml") for f in changed_files): run_scala_style_checks() + should_run_java_style_checks = False if not changed_files or any(f.endswith(".java") or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - # run_java_style_checks() - pass + # Run SBT Checkstyle after the build to prevent a side-effect to the build. + should_run_java_style_checks = True if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") @@ -604,7 +608,7 @@ def main(): detect_binary_inop_with_mima(hadoop_version) # Since we did not build assembly/package before running dev/mima, we need to # do it here because the tests still rely on it; see SPARK-13294 for details. - build_spark_assembly_sbt(hadoop_version) + build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks) # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) diff --git a/dev/sbt-checkstyle b/dev/sbt-checkstyle new file mode 100755 index 0000000000000..8821a7c0e4ccf --- /dev/null +++ b/dev/sbt-checkstyle @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pmesos \ + -Pkafka-0-8 \ + -Pkubernetes \ + -Pyarn \ + -Pflume \ + -Phive \ + -Phive-thriftserver \ + checkstyle test:checkstyle \ + | awk '{if($1~/error/)print}' \ +) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi + diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index dfea762db98c6..2aa355504bf29 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -170,6 +170,16 @@ def __hash__(self): ] ) +avro = Module( + name="avro", + dependencies=[sql], + source_file_regexes=[ + "external/avro", + ], + sbt_test_goals=[ + "avro/test", + ] +) sql_kafka = Module( name="sql-kafka-0-10", diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 3bf7618e1ea96..2fbd6b5e98f7f 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -34,6 +34,7 @@ MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 hadoop-2.7 + hadoop-3.1 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to diff --git a/dev/tox.ini b/dev/tox.ini index 583c1eaaa966b..28dad8f3b5c7c 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -16,4 +16,4 @@ [pycodestyle] ignore=E402,E731,E241,W503,E226,E722,E741,E305 max-line-length=100 -exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* +exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*,dist/* diff --git a/docs/README.md b/docs/README.md index dbea4d64c4298..7da543dd297ad 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,7 +2,7 @@ Welcome to the Spark documentation! This readme will walk you through navigating and building the Spark documentation, which is included here with the Spark source code. You can also find documentation specific to release versions of -Spark at http://spark.apache.org/documentation.html. +Spark at https://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the documentation yourself. Why build it yourself? So that you have the docs that correspond to @@ -79,7 +79,7 @@ jekyll plugin to run `build/sbt unidoc` before building the site so if you haven may take some time as it generates all of the scaladoc and javadoc using [Unidoc](https://github.com/sbt/sbt-unidoc). The jekyll plugin also generates the PySpark docs using [Sphinx](http://sphinx-doc.org/), SparkR docs using [roxygen2](https://cran.r-project.org/web/packages/roxygen2/index.html) and SQL docs -using [MkDocs](http://www.mkdocs.org/). +using [MkDocs](https://www.mkdocs.org/). NOTE: To skip the step of building and copying over the Scala, Java, Python, R and SQL API docs, run `SKIP_API=1 jekyll build`. 
In addition, `SKIP_SCALADOC=1`, `SKIP_PYTHONDOC=1`, `SKIP_RDOC=1` and `SKIP_SQLDOC=1` can be used diff --git a/docs/_layouts/404.html b/docs/_layouts/404.html index 044654413f9c2..78f98b9ede5a7 100755 --- a/docs/_layouts/404.html +++ b/docs/_layouts/404.html @@ -151,7 +151,7 @@

    Not found :(

    - + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index e5af5ae4561c7..88d549c3f1010 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -50,7 +50,7 @@ @@ -114,8 +114,8 @@
  • Hardware Provisioning
  • Building Spark
  • -
  • Contributing to Spark
  • -
  • Third Party Projects
  • +
  • Contributing to Spark
  • +
  • Third Party Projects
  • diff --git a/docs/avro-data-source-guide.md b/docs/avro-data-source-guide.md new file mode 100644 index 0000000000000..d3b81f029d377 --- /dev/null +++ b/docs/avro-data-source-guide.md @@ -0,0 +1,380 @@ +--- +layout: global +title: Apache Avro Data Source Guide +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +Since Spark 2.4 release, [Spark SQL](https://spark.apache.org/docs/latest/sql-programming-guide.html) provides built-in support for reading and writing Apache Avro data. + +## Deploying +The `spark-avro` module is external and not included in `spark-submit` or `spark-shell` by default. + +As with any Spark applications, `spark-submit` is used to launch your application. `spark-avro_{{site.SCALA_BINARY_VERSION}}` +and its dependencies can be directly added to `spark-submit` using `--packages`, such as, + + ./bin/spark-submit --packages org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + +For experimenting on `spark-shell`, you can also use `--packages` to add `org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}` and its dependencies directly, + + ./bin/spark-shell --packages org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + +See [Application Submission Guide](submitting-applications.html) for more details about submitting applications with external dependencies. + +## Load and Save Functions + +Since `spark-avro` module is external, there is no `.avro` API in +`DataFrameReader` or `DataFrameWriter`. + +To load/save data in Avro format, you need to specify the data source option `format` as `avro`(or `org.apache.spark.sql.avro`). +
    +
+{% highlight scala %}
+
+val usersDF = spark.read.format("avro").load("examples/src/main/resources/users.avro")
+usersDF.select("name", "favorite_color").write.format("avro").save("namesAndFavColors.avro")
+
+{% endhighlight %}
    +
+{% highlight java %}
+
+Dataset<Row> usersDF = spark.read().format("avro").load("examples/src/main/resources/users.avro");
+usersDF.select("name", "favorite_color").write().format("avro").save("namesAndFavColors.avro");
+
+{% endhighlight %}
    +
+{% highlight python %}
+
+df = spark.read.format("avro").load("examples/src/main/resources/users.avro")
+df.select("name", "favorite_color").write.format("avro").save("namesAndFavColors.avro")
+
+{% endhighlight %}
    +
+{% highlight r %}
+
+df <- read.df("examples/src/main/resources/users.avro", "avro")
+write.df(select(df, "name", "favorite_color"), "namesAndFavColors.avro", "avro")
+
+{% endhighlight %}
    +
+## to_avro() and from_avro()
+The Avro package provides the function `to_avro` to encode a column as binary in Avro
+format, and `from_avro()` to decode Avro binary data into a column. Both functions transform one column into
+another column, and the input/output SQL data type can be a complex type or a primitive type.
+
+Using Avro records as columns is useful when reading from or writing to a streaming source like Kafka. Each
+Kafka key-value record is augmented with some metadata, such as the ingestion timestamp into Kafka, the offset in Kafka, etc.
+* If the "value" field that contains your data is in Avro, you could use `from_avro()` to extract your data, enrich it, clean it, and then push it downstream to Kafka again or write it out to a file.
+* `to_avro()` can be used to turn structs into Avro records. This is particularly useful when you would like to re-encode multiple columns into a single one when writing data out to Kafka.
+
+Both functions are currently only available in Scala and Java.
    +
+{% highlight scala %}
+import java.nio.file.{Files, Paths}
+
+import org.apache.spark.sql.avro._
+import spark.implicits._
+
+// `from_avro` requires the Avro schema in JSON string format.
+val jsonFormatSchema = new String(Files.readAllBytes(Paths.get("./examples/src/main/resources/user.avsc")))
+
+val df = spark
+  .readStream
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1")
+  .load()
+
+// 1. Decode the Avro data into a struct;
+// 2. Filter by column `favorite_color`;
+// 3. Encode the column `name` in Avro format.
+val output = df
+  .select(from_avro('value, jsonFormatSchema) as 'user)
+  .where("user.favorite_color == \"red\"")
+  .select(to_avro($"user.name") as 'value)
+
+val query = output
+  .writeStream
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("topic", "topic2")
+  .start()
+
+{% endhighlight %}
    +
+{% highlight java %}
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
+import static org.apache.spark.sql.functions.col;
+import org.apache.spark.sql.avro.*;
+
+// `from_avro` requires the Avro schema in JSON string format.
+String jsonFormatSchema = new String(Files.readAllBytes(Paths.get("./examples/src/main/resources/user.avsc")));
+
+Dataset<Row> df = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1")
+  .load();
+
+// 1. Decode the Avro data into a struct;
+// 2. Filter by column `favorite_color`;
+// 3. Encode the column `name` in Avro format.
+Dataset<Row> output = df
+  .select(from_avro(col("value"), jsonFormatSchema).as("user"))
+  .where("user.favorite_color == \"red\"")
+  .select(to_avro(col("user.name")).as("value"));
+
+StreamingQuery query = output
+  .writeStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("topic", "topic2")
+  .start();
+
+{% endhighlight %}
    +
    + +## Data Source Option + +Data source options of Avro can be set using the `.option` method on `DataFrameReader` or `DataFrameWriter`. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+  <tr><th>Property Name</th><th>Default</th><th>Meaning</th><th>Scope</th></tr>
+  <tr>
+    <td><code>avroSchema</code></td>
+    <td>None</td>
+    <td>Optional Avro schema provided by a user in JSON format. The data type and naming of record fields
+    should match the input Avro data or Catalyst data; otherwise the read/write action will fail.</td>
+    <td>read and write</td>
+  </tr>
+  <tr>
+    <td><code>recordName</code></td>
+    <td>topLevelRecord</td>
+    <td>Top level record name in write result, which is required in the Avro spec.</td>
+    <td>write</td>
+  </tr>
+  <tr>
+    <td><code>recordNamespace</code></td>
+    <td>""</td>
+    <td>Record namespace in write result.</td>
+    <td>write</td>
+  </tr>
+  <tr>
+    <td><code>ignoreExtension</code></td>
+    <td>true</td>
+    <td>Controls whether files without the <code>.avro</code> extension are ignored on read.
+    If the option is enabled, all files (with and without the <code>.avro</code> extension) are loaded.</td>
+    <td>read</td>
+  </tr>
+  <tr>
+    <td><code>compression</code></td>
+    <td>snappy</td>
+    <td>The <code>compression</code> option allows specifying a compression codec used in write.
+    Currently supported codecs are <code>uncompressed</code>, <code>snappy</code>, <code>deflate</code>, <code>bzip2</code> and <code>xz</code>.
+    If the option is not set, the configuration <code>spark.sql.avro.compression.codec</code> is taken into account.</td>
+    <td>write</td>
+  </tr>
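For illustration, a minimal sketch of how these per-operation options are passed through `.option()`; the paths below are placeholders rather than files shipped with Spark:

{% highlight scala %}
// Read only files ending in .avro by turning off the default ignoreExtension behavior.
val eventsDF = spark.read
  .format("avro")
  .option("ignoreExtension", "false")
  .load("/tmp/avro/events")

// Write with a per-job codec, overriding spark.sql.avro.compression.codec for this write.
eventsDF.write
  .format("avro")
  .option("compression", "deflate")
  .save("/tmp/avro/events-deflate")
{% endhighlight %}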
    + +## Configuration +Configuration of Avro can be done using the `setConf` method on SparkSession or by running `SET key=value` commands using SQL. + + + + + + + + + + + + + + + + + +
+  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+  <tr>
+    <td>spark.sql.legacy.replaceDatabricksSparkAvro.enabled</td>
+    <td>true</td>
+    <td>If it is set to true, the data source provider <code>com.databricks.spark.avro</code> is mapped to the built-in but external Avro data source module for backward compatibility.</td>
+  </tr>
+  <tr>
+    <td>spark.sql.avro.compression.codec</td>
+    <td>snappy</td>
+    <td>Compression codec used in writing of Avro files. Supported codecs: uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.</td>
+  </tr>
+  <tr>
+    <td>spark.sql.avro.deflate.level</td>
+    <td>-1</td>
+    <td>Compression level for the deflate codec used in writing of Avro files. Valid values are in the range 1 to 9 inclusive, or -1. The default value is -1, which corresponds to level 6 in the current implementation.</td>
+  </tr>
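A short sketch of the two configuration styles described above, assuming an existing `SparkSession` named `spark`:

{% highlight scala %}
// Programmatic form: set the session-level Avro codec and deflate level.
spark.conf.set("spark.sql.avro.compression.codec", "deflate")
spark.conf.set("spark.sql.avro.deflate.level", "5")

// Equivalent SQL form.
spark.sql("SET spark.sql.avro.compression.codec=deflate")
{% endhighlight %}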
+## Compatibility with Databricks spark-avro
+This Avro data source module is originally from and compatible with Databricks's open source repository
+[spark-avro](https://github.com/databricks/spark-avro).
+
+By default, with the SQL configuration `spark.sql.legacy.replaceDatabricksSparkAvro.enabled` enabled, the data source provider `com.databricks.spark.avro` is
+mapped to this built-in Avro module. For Spark tables created with the `Provider` property set to `com.databricks.spark.avro` in the
+catalog meta store, the mapping is essential for loading these tables when you use this built-in Avro module.
+
+Note that in Databricks's [spark-avro](https://github.com/databricks/spark-avro), the implicit classes
+`AvroDataFrameWriter` and `AvroDataFrameReader` were created for the shortcut function `.avro()`. In this
+built-in but external module, both implicit classes are removed. Please use `.format("avro")` in
+`DataFrameWriter` or `DataFrameReader` instead.
+
+If you prefer using your own build of the `spark-avro` jar file, you can simply disable the configuration
+`spark.sql.legacy.replaceDatabricksSparkAvro.enabled`, and use the option `--jars` when deploying your
+applications. Read the [Advanced Dependency Management](https://spark.apache.org/docs/latest/submitting-applications.html#advanced-dependency-management) section in the Application
+Submission Guide for more details.
+
+## Supported types for Avro -> Spark SQL conversion
+Currently Spark supports reading all [primitive types](https://avro.apache.org/docs/1.8.2/spec.html#schema_primitive) and [complex types](https://avro.apache.org/docs/1.8.2/spec.html#schema_complex) under records of Avro.
+  <tr><th>Avro type</th><th>Spark SQL type</th></tr>
+  <tr><td>boolean</td><td>BooleanType</td></tr>
+  <tr><td>int</td><td>IntegerType</td></tr>
+  <tr><td>long</td><td>LongType</td></tr>
+  <tr><td>float</td><td>FloatType</td></tr>
+  <tr><td>double</td><td>DoubleType</td></tr>
+  <tr><td>string</td><td>StringType</td></tr>
+  <tr><td>enum</td><td>StringType</td></tr>
+  <tr><td>fixed</td><td>BinaryType</td></tr>
+  <tr><td>bytes</td><td>BinaryType</td></tr>
+  <tr><td>record</td><td>StructType</td></tr>
+  <tr><td>array</td><td>ArrayType</td></tr>
+  <tr><td>map</td><td>MapType</td></tr>
+  <tr><td>union</td><td>See below</td></tr>
    + +In addition to the types listed above, it supports reading `union` types. The following three types are considered basic `union` types: + +1. `union(int, long)` will be mapped to LongType. +2. `union(float, double)` will be mapped to DoubleType. +3. `union(something, null)`, where something is any supported Avro type. This will be mapped to the same Spark SQL type as that of something, with nullable set to true. +All other union types are considered complex. They will be mapped to StructType where field names are member0, member1, etc., in accordance with members of the union. This is consistent with the behavior when converting between Avro and Parquet. + +It also supports reading the following Avro [logical types](https://avro.apache.org/docs/1.8.2/spec.html#Logical+Types): + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+  <tr><th>Avro logical type</th><th>Avro type</th><th>Spark SQL type</th></tr>
+  <tr><td>date</td><td>int</td><td>DateType</td></tr>
+  <tr><td>timestamp-millis</td><td>long</td><td>TimestampType</td></tr>
+  <tr><td>timestamp-micros</td><td>long</td><td>TimestampType</td></tr>
+  <tr><td>decimal</td><td>fixed</td><td>DecimalType</td></tr>
+  <tr><td>decimal</td><td>bytes</td><td>DecimalType</td></tr>
    +At the moment, it ignores docs, aliases and other properties present in the Avro file. + +## Supported types for Spark SQL -> Avro conversion +Spark supports writing of all Spark SQL types into Avro. For most types, the mapping from Spark types to Avro types is straightforward (e.g. IntegerType gets converted to int); however, there are a few special cases which are listed below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+  <tr><th>Spark SQL type</th><th>Avro type</th><th>Avro logical type</th></tr>
+  <tr><td>ByteType</td><td>int</td><td></td></tr>
+  <tr><td>ShortType</td><td>int</td><td></td></tr>
+  <tr><td>BinaryType</td><td>bytes</td><td></td></tr>
+  <tr><td>DateType</td><td>int</td><td>date</td></tr>
+  <tr><td>TimestampType</td><td>long</td><td>timestamp-micros</td></tr>
+  <tr><td>DecimalType</td><td>fixed</td><td>decimal</td></tr>
    + +You can also specify the whole output Avro schema with the option `avroSchema`, so that Spark SQL types can be converted into other Avro types. The following conversions are not applied by default and require user specified Avro schema: + + + + + + + + + + + + + + + + + + + + + + + +
+  <tr><th>Spark SQL type</th><th>Avro type</th><th>Avro logical type</th></tr>
+  <tr><td>BinaryType</td><td>fixed</td><td></td></tr>
+  <tr><td>StringType</td><td>enum</td><td></td></tr>
+  <tr><td>TimestampType</td><td>long</td><td>timestamp-millis</td></tr>
+  <tr><td>DecimalType</td><td>bytes</td><td>decimal</td></tr>
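As a sketch of how the `avroSchema` option drives one of these conversions, the example below writes a `StringType` column as an Avro `enum`. The DataFrame `df` and its `color` column are assumptions made for illustration; the write fails if a value is not one of the declared symbols:

{% highlight scala %}
// User-specified output schema: `color` is written as an Avro enum instead of a string.
val enumSchema =
  """{
    |  "type": "record", "name": "topLevelRecord",
    |  "fields": [
    |    {"name": "color", "type": {"type": "enum", "name": "Color", "symbols": ["red", "green", "blue"]}}
    |  ]
    |}""".stripMargin

df.select("color")
  .write
  .format("avro")
  .option("avroSchema", enumSchema)
  .save("/tmp/avro/colors-as-enum")
{% endhighlight %}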
    diff --git a/docs/building-spark.md b/docs/building-spark.md index 0236bb05849ad..d3dfd4902a920 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -45,7 +45,7 @@ Other build examples can be found below. ## Building a Runnable Distribution To create a Spark distribution like those distributed by the -[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as +[Spark Downloads](https://spark.apache.org/downloads.html) page, and that is laid out so as to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: @@ -67,7 +67,7 @@ Examples: ./build/mvn -Pyarn -DskipTests clean package # Apache Hadoop 2.7.X and later - ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.3 -DskipTests clean package + ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.7 -DskipTests clean package ## Building With Hive and JDBC Support @@ -164,7 +164,7 @@ prompt. Developers who compile Spark frequently may want to speed up compilation; e.g., by using Zinc (for developers who build with Maven) or by avoiding re-compilation of the assembly JAR (for developers who build with SBT). For more information about how to do this, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html#reducing-build-times). ## Encrypted Filesystems @@ -182,7 +182,7 @@ to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/ ## IntelliJ IDEA or Eclipse For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html). # Running Tests @@ -203,7 +203,7 @@ The following is an example of a command to run the tests: ## Running Individual Tests For information about how to run individual tests, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#running-individual-tests). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html#running-individual-tests). ## PySpark pip installable @@ -215,19 +215,23 @@ If you are building Spark for use in a Python environment and you wish to pip in Alternatively, you can also run make-distribution with the --pip option. -## PySpark Tests with Maven +## PySpark Tests with Maven or SBT If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. ./build/mvn -DskipTests clean package -Phive ./python/run-tests +If you are building PySpark with SBT and wish to run the PySpark tests, you will need to build Spark with Hive support and also build the test components: + + ./build/sbt -Phive clean package + ./build/sbt test:compile + ./python/run-tests + The run-tests script also can be limited to a specific Python version or a specific module ./python/run-tests --python-executables=python --modules=pyspark-sql -**Note:** You can also run Python tests with an sbt build, provided you build Spark with Hive support. 
- ## Running R Tests To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index ac1c336988930..36753f6373b55 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -70,7 +70,7 @@ be safely used as the direct destination of work with the normal rename-based co ### Installation With the relevant libraries on the classpath and Spark configured with valid credentials, -objects can be can be read or written by using their URLs as the path to data. +objects can be read or written by using their URLs as the path to data. For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. @@ -104,7 +104,7 @@ Spark jobs must authenticate with the object stores to access data within them. and `AWS_SESSION_TOKEN` environment variables and sets the associated authentication options for the `s3n` and `s3a` connectors to Amazon S3. 1. In a Hadoop cluster, settings may be set in the `core-site.xml` file. -1. Authentication details may be manually added to the Spark configuration in `spark-default.conf` +1. Authentication details may be manually added to the Spark configuration in `spark-defaults.conf` 1. Alternatively, they can be programmatically set in the `SparkConf` instance used to configure the application's `SparkContext`. @@ -184,7 +184,8 @@ is no need for a workflow of write-then-rename to ensure that files aren't picke while they are still being written. Applications can write straight to the monitored directory. 1. Streams should only be checkpointed to a store implementing a fast and -atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. +atomic `rename()` operation. +Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/configuration.md b/docs/configuration.md index 4d4d0c58dd07d..9714b48d5e69b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -179,6 +179,18 @@ of the most common options to set are: (e.g. 2g, 8g). + + spark.executor.pyspark.memory + Not set + + The amount of memory to be allocated to PySpark in each executor, in MiB + unless otherwise specified. If set, PySpark memory for an executor will be + limited to this amount. If not set, Spark will not limit Python's memory use + and it is up to the application to avoid exceeding the overhead memory space + shared with other non-JVM processes. When PySpark is run in YARN, this memory + is added to executor resource requests. + + spark.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 @@ -208,7 +220,7 @@ of the most common options to set are: stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. - NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or + NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone), MESOS_SANDBOX (Mesos) or LOCAL_DIRS (YARN) environment variables set by the cluster manager. 
@@ -328,6 +340,11 @@ Apart from these, the following properties are also available, and may be useful Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. + + The following symbols, if present will be interpolated: {{APP_ID}} will be replaced by + application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable + verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: + -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc @@ -575,13 +592,15 @@ Apart from these, the following properties are also available, and may be useful spark.maxRemoteBlockSizeFetchToMem - Long.MaxValue + Int.MaxValue - 512 The remote block will be fetched to disk when size of the block is above this threshold in bytes. - This is to avoid a giant request takes too much memory. We can enable this config by setting - a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + This is to avoid a giant request that takes too much memory. By default, this is only enabled + for blocks > 2GB, as those cannot be fetched directly into memory, no matter what resources are + available. But it can be turned down to a much lower value (eg. 200m) to avoid using too much + memory on smaller blocks as well. Note this configuration will affect both shuffle fetch and block manager remote block fetch. For users who enabled external shuffle service, - this feature can only be worked when external shuffle service is newer than Spark 2.2. + this feature can only be used when external shuffle service is newer than Spark 2.2. @@ -905,8 +924,8 @@ Apart from these, the following properties are also available, and may be useful lz4 The codec used to compress internal data such as RDD partitions, event log, broadcast variables - and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, - and snappy. You can also use fully qualified class names to specify the codec, + and shuffle outputs. By default, Spark provides four codecs: lz4, lzf, + snappy, and zstd. You can also use fully qualified class names to specify the codec, e.g. org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, @@ -1210,6 +1229,15 @@ Apart from these, the following properties are also available, and may be useful if it is too small, BlockManager might take a performance hit. + + spark.broadcast.checksum + true + + Whether to enable checksum for broadcast. If enabled, broadcasts will include a checksum, which can + help detect corrupted blocks, at the cost of computing and sending a little more data. It's possible + to disable it if the network has other mechanisms to guarantee data won't be corrupted during broadcast. + + spark.executor.cores @@ -1624,9 +1652,10 @@ Apart from these, the following properties are also available, and may be useful spark.blacklist.killBlacklistedExecutors false - (Experimental) If set to "true", allow Spark to automatically kill, and attempt to re-create, - executors when they are blacklisted. Note that, when an entire node is added to the blacklist, - all of the executors on that node will be killed. 
+ (Experimental) If set to "true", allow Spark to automatically kill the executors + when they are blacklisted on fetch failure or blacklisted for the entire application, + as controlled by spark.blacklist.application.*. Note that, when an entire node is added + to the blacklist, all of the executors on that node will be killed. @@ -1753,6 +1782,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and spark.dynamicAllocation.initialExecutors + spark.dynamicAllocation.executorAllocationRatio @@ -1797,6 +1827,23 @@ Apart from these, the following properties are also available, and may be useful Lower bound for the number of executors if dynamic allocation is enabled. + + spark.dynamicAllocation.executorAllocationRatio + 1 + + By default, the dynamic allocation will request enough executors to maximize the + parallelism according to the number of tasks to process. While this minimizes the + latency of the job, with small tasks this setting can waste a lot of resources due to + executor allocation overhead, as some executor might not even do any work. + This setting allows to set a ratio that will be used to reduce the number of + executors w.r.t. full parallelism. + Defaults to 1.0 to give maximum parallelism. + 0.5 will divide the target number of executors by 2 + The target number of executors computed by the dynamicAllocation can still be overridden + by the spark.dynamicAllocation.minExecutors and + spark.dynamicAllocation.maxExecutors settings + + spark.dynamicAllocation.schedulerBacklogTimeout 1s @@ -2178,7 +2225,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes files are set cluster-wide, and cannot safely be changed by the application. The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`. -They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-default.conf` +They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defaults.conf` In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties. diff --git a/docs/contributing-to-spark.md b/docs/contributing-to-spark.md index 9252545e4a129..ede5584a0cf99 100644 --- a/docs/contributing-to-spark.md +++ b/docs/contributing-to-spark.md @@ -5,4 +5,4 @@ title: Contributing to Spark The Spark team welcomes all forms of contributions, including bug reports, documentation or patches. For the newest information on how to contribute to the project, please read the -[Contributing to Spark guide](http://spark.apache.org/contributing.html). +[Contributing to Spark guide](https://spark.apache.org/contributing.html). diff --git a/docs/index.md b/docs/index.md index 2f009417fafb0..40f628b794c01 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,7 +12,7 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog # Downloading -Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. +Get Spark from the [downloads page](https://spark.apache.org/downloads.html) of the project website. 
This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. Users can also download a "Hadoop free" binary and run Spark with any Hadoop version [by augmenting Spark's classpath](hadoop-provided.html). Scala and Java users can include Spark in their projects using its Maven coordinates and in the future Python users can also install Spark from PyPI. @@ -111,7 +111,7 @@ options for deployment: * [Amazon EC2](https://github.com/amplab/spark-ec2): scripts that let you launch a cluster on EC2 in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager * [Mesos](running-on-mesos.html): deploy a private cluster using - [Apache Mesos](http://mesos.apache.org) + [Apache Mesos](https://mesos.apache.org) * [YARN](running-on-yarn.html): deploy Spark on top of Hadoop NextGen (YARN) * [Kubernetes](running-on-kubernetes.html): deploy Spark on top of Kubernetes @@ -127,20 +127,20 @@ options for deployment: * [Cloud Infrastructures](cloud-integration.html) * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system -* [Contributing to Spark](http://spark.apache.org/contributing.html) -* [Third Party Projects](http://spark.apache.org/third-party-projects.html): related third party Spark projects +* [Contributing to Spark](https://spark.apache.org/contributing.html) +* [Third Party Projects](https://spark.apache.org/third-party-projects.html): related third party Spark projects **External Resources:** -* [Spark Homepage](http://spark.apache.org) -* [Spark Community](http://spark.apache.org/community.html) resources, including local meetups +* [Spark Homepage](https://spark.apache.org) +* [Spark Community](https://spark.apache.org/community.html) resources, including local meetups * [StackOverflow tag `apache-spark`](http://stackoverflow.com/questions/tagged/apache-spark) -* [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here +* [Mailing Lists](https://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/6/), [slides](http://ampcamp.berkeley.edu/6/) and [exercises](http://ampcamp.berkeley.edu/6/exercises/) are available online for free. -* [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), +* [Code Examples](https://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index da90342406c84..2316f175676ee 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -264,3 +264,11 @@ within it for the various settings. 
For example: A full example is also available in `conf/fairscheduler.xml.template`. Note that any pools not configured in the XML file will simply get default values for all settings (scheduling mode FIFO, weight 1, and minShare 0). + +## Scheduling using JDBC Connections +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + +{% highlight SQL %} +SET spark.sql.thriftserver.scheduler.pool=accounting; +{% endhighlight %} diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index d660655e193eb..b3d109039da4d 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -455,11 +455,29 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat ## Naive Bayes [Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple -probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence -assumptions between the features. The `spark.ml` implementation currently supports both [multinomial -naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) +probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence +assumptions between every pair of features. + +Naive Bayes can be trained very efficiently. With a single pass over the training data, +it computes the conditional probability distribution of each feature given each label. +For prediction, it applies Bayes' theorem to compute the conditional probability distribution +of each label given an observation. + +MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). + +*Input data*: +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each feature represents a term. +A feature's value is the frequency of the term (in multinomial Naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes). +Feature values must be *non-negative*. The model type is selected with an optional parameter +"multinomial" or "bernoulli" with "multinomial" as the default. +For document classification, the input feature vectors should usually be sparse vectors. +Since the training data is only used once, it is not necessary to cache it. + +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +setting the parameter $\lambda$ (default to $1.0$). **Examples** diff --git a/docs/ml-features.md b/docs/ml-features.md index 7aed2341584fc..882b895a9d154 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -585,7 +585,11 @@ for more details on the API. ## StringIndexer `StringIndexer` encodes a string column of labels to a column of label indices. -The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. 
+The indices are in `[0, numLabels)`, and four ordering options are supported: +"frequencyDesc": descending order by label frequency (most frequent label assigned 0), +"frequencyAsc": ascending order by label frequency (least frequent label assigned 0), +"alphabetDesc": descending alphabetical order, and "alphabetAsc": ascending alphabetical order +(default = "frequencyDesc"). The unseen labels will be put at index numLabels if user chooses to keep them. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or @@ -1429,7 +1433,7 @@ for more details on the API. ## Imputer -The `Imputer` transformer completes missing values in a dataset, either using the mean or the +The `Imputer` estimator completes missing values in a dataset, either using the mean or the median of the columns in which the missing values are located. The input columns should be of `DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly creates incorrect values for columns containing categorical features. Imputer can impute custom values @@ -1593,10 +1597,25 @@ Suppose `a` and `b` are double columns, we use the following simple examples to * `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` are coefficients. `RFormula` produces a vector column of features and a double or string column of label. -Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. -If the label column is of type string, it will be first transformed to double with `StringIndexer`. +Like when formulas are used in R for linear regression, numeric columns will be cast to doubles. +As to string input columns, they will first be transformed with [StringIndexer](ml-features.html#stringindexer) using ordering determined by `stringOrderType`, +and the last category after ordering is dropped, then the doubles will be one-hot encoded. + +Suppose a string feature column containing values `{'b', 'a', 'b', 'a', 'c', 'b'}`, we set `stringOrderType` to control the encoding: +~~~ +stringOrderType | Category mapped to 0 by StringIndexer | Category dropped by RFormula +----------------|---------------------------------------|--------------------------------- +'frequencyDesc' | most frequent category ('b') | least frequent category ('c') +'frequencyAsc' | least frequent category ('c') | most frequent category ('b') +'alphabetDesc' | last alphabetical category ('c') | first alphabetical category ('a') +'alphabetAsc' | first alphabetical category ('a') | last alphabetical category ('c') +~~~ + +If the label column is of type string, it will be first transformed to double with [StringIndexer](ml-features.html#stringindexer) using `frequencyDesc` ordering. If the label column does not exist in the DataFrame, the output label column will be created from the specified response variable in the formula. +**Note:** The ordering option `stringOrderType` is NOT used for the label column. When the label column is indexed, it uses the default descending frequency ordering in `StringIndexer`. 
+ **Examples** Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`: diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index e4736411fb5fe..2047065f71eb8 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -289,7 +289,7 @@ In the `spark.mllib` package, there were several breaking changes. The first ch In the `spark.ml` package, the main API changes are from Spark SQL. We list the most important changes here: -* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in `spark.ml` which used to use SchemaRDD now use DataFrame. +* The old [SchemaRDD](https://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in `spark.ml` which used to use SchemaRDD now use DataFrame. * In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. * Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. diff --git a/docs/ml-statistics.md b/docs/ml-statistics.md index abfb3cab1e566..6c82b3bb94b24 100644 --- a/docs/ml-statistics.md +++ b/docs/ml-statistics.md @@ -89,4 +89,32 @@ Refer to the [`ChiSquareTest` Python docs](api/python/index.html#pyspark.ml.stat {% include_example python/ml/chi_square_test_example.py %} + + +## Summarizer + +We provide vector column summary statistics for `Dataframe` through `Summarizer`. +Available metrics are the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. + +
    +
+The following example demonstrates using [`Summarizer`](api/scala/index.html#org.apache.spark.ml.stat.Summarizer$)
+to compute the mean and variance for a vector column of the input DataFrame, with and without a weight column.
+
+{% include_example scala/org/apache/spark/examples/ml/SummarizerExample.scala %}
    + +
+The following example demonstrates using [`Summarizer`](api/java/org/apache/spark/ml/stat/Summarizer.html)
+to compute the mean and variance for a vector column of the input DataFrame, with and without a weight column.
+
+{% include_example java/org/apache/spark/examples/ml/JavaSummarizerExample.java %}
    + +
+Refer to the [`Summarizer` Python docs](api/python/index.html#pyspark.ml.stat.Summarizer) for details on the API.
+
+{% include_example python/ml/summarizer_example.py %}
    +
    \ No newline at end of file diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 5066bb29387dc..eca101132d2e5 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -317,7 +317,7 @@ Refer to the [`Matrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib. from pyspark.mllib.linalg import Matrix, Matrices # Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) -dm2 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) +dm2 = Matrices.dense(3, 2, [1, 3, 5, 2, 4, 6]) # Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8]) @@ -624,7 +624,7 @@ from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry # Create an RDD of coordinate entries. # - This can be done explicitly with the MatrixEntry class: -entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)]) +entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(2, 1, 3.7)]) # - or using (long, long, float) tuples: entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)]) diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index d9dbbab4840a3..c65ecdcb67ee4 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -462,13 +462,13 @@ $$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{ Normalized Discounted Cumulative Gain $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1} - \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+1)}} \\ + \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\ \text{Where} \\ \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\ - \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+1)}$ + \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$ - NDCG at k is a + NDCG at k is a measure of how many of the first k recommended documents are in the set of true relevant documents averaged across all users. In contrast to precision at k, this metric takes into account the order of the recommendations (documents are assumed to be in order of decreasing relevance). diff --git a/docs/monitoring.md b/docs/monitoring.md index 6eaf33135744d..2717dd091c751 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -435,6 +435,7 @@ set of sinks to which metrics are reported. The following instances are currentl * `executor`: A Spark executor. * `driver`: The Spark driver process (the process in which your SparkContext is created). * `shuffleService`: The Spark shuffle service. +* `applicationMaster`: The Spark ApplicationMaster when running on YARN. Each instance can report to zero or more _sinks_. Sinks are contained in the `org.apache.spark.metrics.sink` package: diff --git a/docs/quick-start.md b/docs/quick-start.md index f1a2096cd4dbd..ef7af6c3f6cec 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -12,7 +12,7 @@ interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. To follow along with this guide, first, download a packaged release of Spark from the -[Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, +[Spark website](https://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. 
Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index b6424090d2fea..d95b757f36859 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -106,7 +106,7 @@ You can also use `bin/pyspark` to launch an interactive Python shell. If you wish to access HDFS data, you need to use a build of PySpark linking to your version of HDFS. -[Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage +[Prebuilt packages](https://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. Finally, you need to import some Spark classes into your program. Add the following line: @@ -1569,7 +1569,7 @@ as Spark does not support two contexts running concurrently in the same program. # Where to Go from Here -You can see some [example Spark programs](http://spark.apache.org/examples.html) on the Spark website. +You can see some [example Spark programs](https://spark.apache.org/examples.html) on the Spark website. In addition, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e9e1f3e280609..c83dad6df1e7b 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -117,6 +117,45 @@ If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0. spark-submit. Finally, notice that in the above example we specify a jar with a specific URI with a scheme of `local://`. This URI is the location of the example jar that is already in the Docker image. +## Client Mode + +Starting with Spark 2.4.0, it is possible to run Spark applications on Kubernetes in client mode. When your application +runs in client mode, the driver can run inside a pod or on a physical host. When running an application in client mode, +it is recommended to account for the following factors: + +### Client Mode Networking + +Spark executors must be able to connect to the Spark driver over a hostname and a port that is routable from the Spark +executors. The specific network configuration that will be required for Spark to work in client mode will vary per +setup. If you run your driver inside a Kubernetes pod, you can use a +[headless service](https://kubernetes.io/docs/concepts/services-networking/service/#headless-services) to allow your +driver pod to be routable from the executors by a stable hostname. When deploying your headless service, ensure that +the service's label selector will only match the driver pod and no other pods; it is recommended to assign your driver +pod a sufficiently unique label and to use that label in the label selector of the headless service. 
Specify the driver's +hostname via `spark.driver.host` and your spark driver's port to `spark.driver.port`. + +### Client Mode Executor Pod Garbage Collection + +If you run your Spark driver in a pod, it is highly recommended to set `spark.driver.pod.name` to the name of that pod. +When this property is set, the Spark scheduler will deploy the executor pods with an +[OwnerReference](https://kubernetes.io/docs/concepts/workloads/controllers/garbage-collection/), which in turn will +ensure that once the driver pod is deleted from the cluster, all of the application's executor pods will also be deleted. +The driver will look for a pod with the given name in the namespace specified by `spark.kubernetes.namespace`, and +an OwnerReference pointing to that pod will be added to each executor pod's OwnerReferences list. Be careful to avoid +setting the OwnerReference to a pod that is not actually that driver pod, or else the executors may be terminated +prematurely when the wrong pod is deleted. + +If your application is not running inside a pod, or if `spark.driver.pod.name` is not set when your application is +actually running in a pod, keep in mind that the executor pods may not be properly deleted from the cluster when the +application exits. The Spark scheduler attempts to delete these pods, but if the network request to the API server fails +for any reason, these pods will remain in the cluster. The executor processes should exit when they cannot reach the +driver, so the executor pods should not consume compute resources (cpu and memory) in the cluster after your application +exits. + +### Authentication Parameters + +Use the exact prefix `spark.kubernetes.authenticate` for Kubernetes authentication parameters in client mode. + ## Dependency Management If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to @@ -140,6 +179,42 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` +To use a secret through an environment variable use the following options to the `spark-submit` command: +``` +--conf spark.kubernetes.driver.secretKeyRef.ENV_NAME=name:key +--conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key +``` + +## Using Kubernetes Volumes + +Starting with Spark 2.4.0, users can mount the following types of Kubernetes [volumes](https://kubernetes.io/docs/concepts/storage/volumes/) into the driver and executor pods: +* [hostPath](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath): mounts a file or directory from the host node’s filesystem into a pod. +* [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir): an initially empty volume created when a pod is assigned to a node. +* [persistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/volumes/#persistentvolumeclaim): used to mount a `PersistentVolume` into a pod. + +To mount a volume of any of the types above into the driver pod, use the following configuration property: + +``` +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path= +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly= +``` + +Specifically, `VolumeType` can be one of the following values: `hostPath`, `emptyDir`, and `persistentVolumeClaim`. `VolumeName` is the name you want to use for the volume under the `volumes` field in the pod specification. 
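Tying the client-mode settings and the volume properties above together, the following is a rough PySpark sketch rather than a definitive recipe: the API server address, container image, service hostname, port, pod name and volume name are all placeholders, and the executor-side volume properties shown use the `spark.kubernetes.executor.` prefix described further below.

```python
from pyspark.sql import SparkSession

# Rough sketch of a client-mode session whose driver runs inside a pod.
# Every host, port, image, pod and volume name below is a placeholder.
spark = (SparkSession.builder
    .master("k8s://https://<k8s-apiserver-host>:<k8s-apiserver-port>")
    .config("spark.kubernetes.container.image", "<spark-image>")
    # Client-mode essentials: a driver address routable from the executors and,
    # because the driver runs in a pod, that pod's name so executors get an OwnerReference.
    .config("spark.driver.host", "spark-driver-svc.default.svc.cluster.local")
    .config("spark.driver.port", "29413")
    .config("spark.kubernetes.driver.pod.name", "<driver-pod-name>")
    # Mount an emptyDir scratch volume into each executor pod at /scratch.
    .config("spark.kubernetes.executor.volumes.emptyDir.spark-scratch.mount.path", "/scratch")
    .config("spark.kubernetes.executor.volumes.emptyDir.spark-scratch.mount.readOnly", "false")
    .getOrCreate())
```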
+ +Each supported type of volumes may have some specific configuration options, which can be specified using configuration properties of the following form: + +``` +spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName]= +``` + +For example, the claim name of a `persistentVolumeClaim` with volume name `checkpointpvc` can be specified using the following property: + +``` +spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=check-point-pvc-claim +``` + +The configuration properties for mounting volumes into the executor pods use prefix `spark.kubernetes.executor.` instead of `spark.kubernetes.driver.`. For a complete list of available options for each supported type of volumes, please refer to the [Spark Properties](#spark-properties) section below. + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -252,28 +327,17 @@ RBAC authorization and how to configure Kubernetes service accounts for pods, pl [Using RBAC Authorization](https://kubernetes.io/docs/admin/authorization/rbac/) and [Configure Service Accounts for Pods](https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/). -## Client Mode - -Client mode is not currently supported. - ## Future Work -There are several Spark on Kubernetes features that are currently being incubated in a fork - -[apache-spark-on-k8s/spark](https://github.com/apache-spark-on-k8s/spark), which are expected to eventually make it into -future versions of the spark-kubernetes integration. +There are several Spark on Kubernetes features that are currently being worked on or planned to be worked on. Those features are expected to eventually make it into future versions of the spark-kubernetes integration. Some of these include: -* PySpark -* R -* Dynamic Executor Scaling +* Dynamic Resource Allocation and External Shuffle Service * Local File Dependency Management * Spark Application Management * Job Queues and Resource Management -You can refer to the [documentation](https://apache-spark-on-k8s.github.io/userdocs/) if you want to try these features -and provide feedback to the development team. - # Configuration See the [configuration page](configuration.html) for information on Spark configurations. The following configurations are @@ -321,6 +385,13 @@ specific to Spark on Kubernetes. Container image pull policy used when pulling images within Kubernetes. + + spark.kubernetes.container.image.pullSecrets + + + Comma separated list of Kubernetes secrets used to pull images from private image registries. + + spark.kubernetes.allocation.batch.size 5 @@ -342,7 +413,7 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not provide - a scheme). + a scheme). In client mode, use spark.kubernetes.authenticate.caCertFile instead. @@ -351,7 +422,7 @@ specific to Spark on Kubernetes. Path to the client key file for authenticating against the Kubernetes API server when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not provide - a scheme). + a scheme). In client mode, use spark.kubernetes.authenticate.clientKeyFile instead. @@ -360,7 +431,7 @@ specific to Spark on Kubernetes. 
Path to the client cert file for authenticating against the Kubernetes API server when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not - provide a scheme). + provide a scheme). In client mode, use spark.kubernetes.authenticate.clientCertFile instead. @@ -369,7 +440,7 @@ specific to Spark on Kubernetes. OAuth token to use when authenticating against the Kubernetes API server when starting the driver. Note that unlike the other authentication options, this is expected to be the exact string value of the token to use for - the authentication. + the authentication. In client mode, use spark.kubernetes.authenticate.oauthToken instead. @@ -378,7 +449,7 @@ specific to Spark on Kubernetes. Path to the OAuth token file containing the token to use when authenticating against the Kubernetes API server when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not - provide a scheme). + provide a scheme). In client mode, use spark.kubernetes.authenticate.oauthTokenFile instead. @@ -387,7 +458,8 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS from the driver pod when requesting executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.caCertFile instead. @@ -395,10 +467,9 @@ specific to Spark on Kubernetes. (none) Path to the client key file for authenticating against the Kubernetes API server from the driver pod when requesting - executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). If this is specified, it is highly - recommended to set up TLS for the driver submission server, as this value is sensitive information that would be - passed to the driver pod in plaintext otherwise. + executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod as + a Kubernetes secret. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + In client mode, use spark.kubernetes.authenticate.clientKeyFile instead. @@ -407,7 +478,8 @@ specific to Spark on Kubernetes. Path to the client cert file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This file must be located on the submitting machine's disk, and will be uploaded to the - driver pod. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + driver pod as a Kubernetes secret. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + In client mode, use spark.kubernetes.authenticate.clientCertFile instead. @@ -416,9 +488,8 @@ specific to Spark on Kubernetes. OAuth token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. Note that unlike the other authentication options, this must be the exact string value of - the token to use for the authentication. This token value is uploaded to the driver pod. 
If this is specified, it is - highly recommended to set up TLS for the driver submission server, as this value is sensitive information that would - be passed to the driver pod in plaintext otherwise. + the token to use for the authentication. This token value is uploaded to the driver pod as a Kubernetes secret. + In client mode, use spark.kubernetes.authenticate.oauthToken instead. @@ -427,9 +498,8 @@ specific to Spark on Kubernetes. Path to the OAuth token file containing the token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. Note that unlike the other authentication options, this file must contain the exact string value of - the token to use for the authentication. This token value is uploaded to the driver pod. If this is specified, it is - highly recommended to set up TLS for the driver submission server, as this value is sensitive information that would - be passed to the driver pod in plaintext otherwise. + the token to use for the authentication. This token value is uploaded to the driver pod as a secret. In client mode, use + spark.kubernetes.authenticate.oauthTokenFile instead. @@ -438,7 +508,8 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.caCertFile instead. @@ -447,7 +518,8 @@ specific to Spark on Kubernetes. Path to the client key file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.clientKeyFile instead. @@ -456,7 +528,8 @@ specific to Spark on Kubernetes. Path to the client cert file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.clientCertFile instead. @@ -465,7 +538,8 @@ specific to Spark on Kubernetes. Path to the file containing the OAuth token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Note that unlike the other authentication options, this file must contain the exact string value of the token to use for the authentication. + Note that unlike the other authentication options, this file must contain the exact string value of the token to use + for the authentication. In client mode, use spark.kubernetes.authenticate.oauthTokenFile instead. @@ -474,7 +548,48 @@ specific to Spark on Kubernetes. Service account that is used when running the driver pod. The driver pod uses this service account when requesting executor pods from the API server. Note that this cannot be specified alongside a CA cert file, client key file, - client cert file, and/or OAuth token. + client cert file, and/or OAuth token. 
In client mode, use spark.kubernetes.authenticate.serviceAccountName instead. + + + + spark.kubernetes.authenticate.caCertFile + (none) + + In client mode, path to the CA cert file for connecting to the Kubernetes API server over TLS when + requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.clientKeyFile + (none) + + In client mode, path to the client key file for authenticating against the Kubernetes API server + when requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.clientCertFile + (none) + + In client mode, path to the client cert file for authenticating against the Kubernetes API server + when requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.oauthToken + (none) + + In client mode, the OAuth token to use when authenticating against the Kubernetes API server when + requesting executors. Note that unlike the other authentication options, this must be the exact string value of + the token to use for the authentication. + + + + spark.kubernetes.authenticate.oauthTokenFile + (none) + + In client mode, path to the file containing the OAuth token to use when authenticating against the Kubernetes API + server when requesting executors. @@ -517,8 +632,11 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.pod.name (none) - Name of the driver pod. If not set, the driver pod name is set to "spark.app.name" suffixed by the current timestamp - to avoid name conflicts. + Name of the driver pod. In cluster mode, if this is not set, the driver pod name is set to "spark.app.name" + suffixed by the current timestamp to avoid name conflicts. In client mode, if your application is running + inside a pod, it is highly recommended to set this to the name of the pod your driver is running in. Setting this + value in client mode allows the driver to become the owner of its executor pods, which in turn allows the executor + pods to be garbage collected by the cluster. @@ -602,4 +720,83 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. + + spark.kubernetes.driver.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the driver container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.driver.secretKeyRef.ENV_VAR=spark-secret:key. + + + + spark.kubernetes.executor.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path + (none) + + Add the Kubernetes Volume named VolumeName of the VolumeType type to the driver pod on the path specified in the value. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly + (none) + + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. 
+ + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName] + (none) + + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value, must conform with Kubernetes option format. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path + (none) + + Add the Kubernetes Volume named VolumeName of the VolumeType type to the executor pod on the path specified in the value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnly + false + + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].options.[OptionName] + (none) + + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + + + + spark.kubernetes.memoryOverheadFactor + 0.1 + + This sets the Memory Overhead Factor that will allocate memory to non-JVM memory, which includes off-heap memory allocations, non-JVM tasks, and various systems processes. For JVM-based jobs this value will default to 0.10 and 0.40 for non-JVM jobs. + This is done as non-JVM tasks need more non-JVM heap space and such tasks commonly fail with "Memory Overhead Exceeded" errors. This prempts this error with a higher default. + + + + spark.kubernetes.pyspark.pythonVersion + "2" + + This sets the major Python version of the docker image used to run the driver and executor containers. Can either be 2 or 3. + + diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 3c2a1501ca692..b473e654563d6 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -174,6 +174,8 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. +Note that the `MesosClusterDispatcher` does not support authentication. You should ensure that all network access to it is +protected (port 7077 by default). By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispatcher, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. @@ -670,7 +672,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.dispatcher.historyServer.url (none) - Set the URL of the history + Set the URL of the history server. The dispatcher will then link each driver to its entry in the history server. @@ -753,6 +755,18 @@ See the [configuration page](configuration.html) for information on Spark config spark.cores.max is reached + + spark.mesos.appJar.local.resolution.mode + host + + Provides support for the `local:///` scheme to reference the app jar resource in cluster mode. + If user uses a local resource (`local:///path/to/jar`) and the config option is not used it defaults to `host` eg. 
+ the mesos fetcher tries to get the resource from the host's file system. + If the value is unknown it prints a warning msg in the dispatcher logs and defaults to `host`. + If the value is `container` then spark submit in the container will use the jar in the container's path: + `/path/to/jar`. + + # Troubleshooting and Debugging diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index ceda8a3ae2403..e3d67c34d53eb 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -61,7 +61,7 @@ In `cluster` mode, the driver runs on a different machine than the client, so `S # Preparations Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. +Binary distributions can be downloaded from the [downloads page](https://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). To make Spark runtime jars accessible from YARN side, you can specify `spark.yarn.archive` or `spark.yarn.jars`. For details please refer to [Spark Properties](running-on-yarn.html#spark-properties). If neither `spark.yarn.archive` nor `spark.yarn.jars` is specified, Spark will create a zip file with all jars under `$SPARK_HOME/jars` and upload it to the distributed cache. @@ -133,9 +133,8 @@ To use a custom metrics.properties for the application master and executors, upd spark.yarn.am.waitTime 100s - In cluster mode, time for the YARN Application Master to wait for the - SparkContext to be initialized. In client mode, time for the YARN Application Master to wait - for the driver to connect to it. + Only used in cluster mode. Time for the YARN Application Master to wait for the + SparkContext to be initialized. @@ -219,9 +218,10 @@ To use a custom metrics.properties for the application master and executors, upd spark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to + Comma-separated list of schemes for which resources will be downloaded to the local disk prior to being added to YARN's distributed cache. For use in cases where the YARN service does not - support schemes that are supported by Spark, like http, https and ftp. + support schemes that are supported by Spark, like http, https and ftp, or jars required to be in the + local YARN client's classpath. Wildcard '*' is denoted to download resources for all the schemes. @@ -412,6 +412,23 @@ To use a custom metrics.properties for the application master and executors, upd name matches both the include and the exclude pattern, this file will be excluded eventually. + + spark.yarn.blacklist.executor.launch.blacklisting.enabled + false + + Flag to enable blacklisting of nodes having YARN resource allocation problems. + The error limit for blacklisting can be configured by + spark.blacklist.application.maxFailedExecutorsPerNode. + + + + spark.yarn.metrics.namespace + (none) + + The root namespace for AM metrics reporting. + If it is not set then the YARN application ID is used. + + # Important notes @@ -425,9 +442,12 @@ To use a custom metrics.properties for the application master and executors, upd Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. 
-In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home -directory, Spark will also automatically obtain delegation tokens for the service hosting the -staging directory of the Spark application. +In YARN mode, when accessing Hadoop filesystems, Spark will automatically obtain delegation tokens +for: + +- the filesystem hosting the staging directory of the Spark application (which is the default + filesystem if `spark.yarn.stagingDir` is not set); +- if Hadoop federation is enabled, all the federated filesystems in the configuration. If an application needs to interact with other secure Hadoop filesystems, their URIs need to be explicitly provided to Spark at launch time. This is done by listing them in the diff --git a/docs/security.md b/docs/security.md index 8c0c66fb5a285..7fb3e17de94c9 100644 --- a/docs/security.md +++ b/docs/security.md @@ -22,7 +22,12 @@ secrets to be secure. For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. This secret will be shared by all the daemons and applications, so this deployment configuration is -not as secure as the above, especially when considering multi-tenant clusters. +not as secure as the above, especially when considering multi-tenant clusters. In this +configuration, a user with the secret can effectively impersonate any other user. + +The Rest Submission Server and the MesosClusterDispatcher do not support authentication. You should +ensure that all network access to the REST API & MesosClusterDispatcher (port 6066 and 7077 +respectively by default) are restricted to hosts that are trusted to submit jobs. @@ -44,7 +49,7 @@ not as secure as the above, especially when considering multi-tenant clusters. Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC authentication must also be enabled and properly configured. AES encryption uses the -[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's +[Apache Commons Crypto](https://commons.apache.org/proper/commons-crypto/) library, and Spark's configuration system allows access to that library's configuration for advanced users. There is also support for SASL-based encryption, although it should be considered deprecated. It @@ -164,7 +169,7 @@ The following settings cover enabling encryption for data written to disk: ## Authentication and Authorization -Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). +Enabling authentication for the Web UIs is done using [javax servlet filters](https://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). You will need a filter that implements the authentication method you want to deploy. Spark does not provide any built-in authentication filters. @@ -177,7 +182,7 @@ ACLs can be configured for either users or groups. Configuration entries accept lists as input, meaning multiple users or groups can be given the desired privileges. This can be used if you run on a shared cluster and have a set of administrators or developers who need to monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL -means that all users will have the respective pivilege. By default, only the user submitting the +means that all users will have the respective privilege. By default, only the user submitting the application is added to the ACLs. 
Group membership is established by using a configurable group mapping provider. The mapper is @@ -278,7 +283,7 @@ To enable authorization in the SHS, a few extra options are used:
    Property NameDefaultMeaning
    - + - + - +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablespark.history.ui.acls.enable false Specifies whether ACLs should be checked to authorize users viewing the applications in @@ -292,7 +297,7 @@ To enable authorization in the SHS, a few extra options are used:
    spark.history.ui.admin.aclsspark.history.ui.admin.acls None Comma separated list of users that have view access to all the Spark applications in history @@ -300,7 +305,7 @@ To enable authorization in the SHS, a few extra options are used:
    spark.history.ui.admin.acls.groupsspark.history.ui.admin.acls.groups None Comma separated list of groups that have view access to all the Spark applications in history @@ -446,6 +451,27 @@ replaced with one of the above namespaces.
    +Spark also supports retrieving `${ns}.keyPassword`, `${ns}.keyStorePassword` and `${ns}.trustStorePassword` from +[Hadoop Credential Providers](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/CredentialProviderAPI.html). +User could store password into credential file and make it accessible by different components, like: + +``` +hadoop credential create spark.ssl.keyPassword -value password \ + -provider jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks +``` + +To configure the location of the credential provider, set the `hadoop.security.credential.provider.path` +config option in the Hadoop configuration used by Spark, like: + +``` + + hadoop.security.credential.provider.path + jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks + +``` + +Or via SparkConf "spark.hadoop.hadoop.security.credential.provider.path=jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks". + ## Preparing the key stores Key stores can be generated by `keytool` program. The reference documentation for this tool for @@ -466,7 +492,7 @@ distributed with the application using the `--files` command line argument (or t configuration should just reference the file name with no absolute path. Distributing local key stores this way may require the files to be staged in HDFS (or other similar -distributed file system used by the cluster), so it's recommended that the undelying file system be +distributed file system used by the cluster), so it's recommended that the underlying file system be configured with security in mind (e.g. by enabling authentication and wire encryption). ### Standalone mode @@ -480,6 +506,7 @@ can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that c provided by the user on the client side are not used. ### Mesos mode + Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. Spark allows the specification of file-based and environment variable based secrets with `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. @@ -541,8 +568,12 @@ Security. # Configuring Ports for Network Security -Spark makes heavy use of the network, and some environments have strict requirements for using tight -firewall settings. Below are the primary ports that Spark uses for its communication and how to +Generally speaking, a Spark cluster and its services are not deployed on the public internet. +They are generally private services, and should only be accessible within the network of the +organization that deploys Spark. Access to the hosts and ports used by Spark services should +be limited to origin hosts that need to access the services. + +Below are the primary ports that Spark uses for its communication and how to configure those ports. ## Standalone mode only @@ -576,6 +607,14 @@ configure those ports. SPARK_MASTER_PORT Set to "0" to choose a port randomly. Standalone mode only. + + External Service + Standalone Master + 6066 + Submit job to cluster via REST API + spark.master.rest.port + Use spark.master.rest.enabled to enable/disable this service. Standalone mode only. + Standalone Master Standalone Worker diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f06e72a387df1..7975b0c8b11ca 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -254,6 +254,18 @@ SPARK_WORKER_OPTS supports the following system properties: especially if you run jobs very frequently. 
+ + spark.storage.cleanupFilesAfterExecutorExit + true + + Enable cleanup non-shuffle files(such as temp. shuffle blocks, cached RDD/broadcast blocks, + spill files, etc) of worker directories following executor exits. Note that this doesn't + overlap with `spark.worker.cleanup.enabled`, as this enables cleanup of non-shuffle files in + local directories of a dead executor, while `spark.worker.cleanup.enabled` enables cleanup of + all files/subdirectories of a stopped and timeout application. + This only affects Standalone mode, support of other cluster manangers can be added in the future. + + spark.worker.ui.compressedLogFileLengthCacheSize 100 @@ -350,8 +362,15 @@ You can run Spark alongside your existing Hadoop cluster by just launching it as # Configuring Ports for Network Security -Spark makes heavy use of the network, and some environments have strict requirements for using -tight firewall settings. For a complete list of ports to configure, see the +Generally speaking, a Spark cluster and its services are not deployed on the public internet. +They are generally private services, and should only be accessible within the network of the +organization that deploys Spark. Access to the hosts and ports used by Spark services should +be limited to origin hosts that need to access the services. + +This is particularly important for clusters using the standalone resource manager, as they do +not support fine-grained access control in a way that other resource managers do. + +For a complete list of ports to configure, see the [security page](security.html#configuring-ports-for-network-security). # High Availability @@ -364,7 +383,7 @@ By default, standalone scheduling clusters are resilient to Worker failures (ins Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected. -Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/current/zookeeperStarted.html). +Learn more about getting started with ZooKeeper [here](https://zookeeper.apache.org/doc/current/zookeeperStarted.html). **Configuration** @@ -407,6 +426,6 @@ In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spa **Details** -* This solution can be used in tandem with a process monitor/manager like [monit](http://mmonit.com/monit/), or just to enable manual recovery via restart. +* This solution can be used in tandem with a process monitor/manager like [monit](https://mmonit.com/monit/), or just to enable manual recovery via restart. * While filesystem recovery seems straightforwardly better than not doing any recovery at all, this mode may be suboptimal for certain development or experimental purposes. In particular, killing a master via stop-master.sh does not clean up its recovery state, so whenever you start a new Master, it will enter recovery mode. This could increase the startup time by up to 1 minute if it needs to wait for all previously-registered Workers/clients to timeout. 
* While it's not officially supported, you could mount an NFS directory as the recovery directory. If the original Master node dies completely, you could then start a Master on a different node, which would correctly recover all previously registered Workers/applications (equivalent to ZooKeeper recovery). Future applications will have to be able to find the new Master, however, in order to register. diff --git a/docs/sparkr.md b/docs/sparkr.md index 7fabab5d38f16..b4248e8bb21de 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -128,7 +128,7 @@ head(df) SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating SparkDataFrames from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active SparkSession will be used automatically. -SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](http://spark.apache.org/third-party-projects.html), you can find data source connectors for popular file formats like Avro. These packages can either be added by +SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](https://spark.apache.org/third-party-projects.html), you can find data source connectors for popular file formats like Avro. These packages can either be added by specifying `--packages` with `spark-submit` or `sparkR` commands, or if initializing SparkSession with `sparkPackages` parameter when in an interactive R shell or from RStudio.
    @@ -664,6 +664,10 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. -## Upgrading to Spark 2.4.0 +## Upgrading to SparkR 2.3.1 and above - - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. + - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-base. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. + +## Upgrading to SparkR 2.4.0 + + - Previously, we don't check the validity of the size of the last layer in `spark.mlp`. For example, if the training data only has two labels, a `layers` param like `c(1, 3)` doesn't cause an error previously, now it does. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e8ff1470970f7..3749094569271 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -964,7 +964,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession Sets the compression codec used when writing Parquet files. If either `compression` or `parquet.compression` is specified in the table-specific options/properties, the precedence would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: - none, uncompressed, snappy, gzip, lzo. + none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd. @@ -1017,7 +1017,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also Property NameDefaultMeaning spark.sql.orc.impl - hive + native The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1. @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.2. + options are 0.12.0 through 2.3.3. @@ -1302,9 +1302,33 @@ the following case-insensitive options: dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a - subquery in parentheses. + The JDBC table that should be read from or written into. Note that when using it in the read + path anything that is valid in a FROM clause of a SQL query can be used. + For example, instead of a full table you could also use a subquery in parentheses. It is not + allowed to specify `dbtable` and `query` options at the same time. + + + + query + + A query that will be used to read data into Spark. 
The specified query will be parenthesized and used + as a subquery in the FROM clause. Spark will also assign an alias to the subquery clause. + As an example, Spark will issue a query of the following form to the JDBC source.

    + SELECT <columns> FROM (<user_specified_query>) spark_gen_alias

+ Below are a couple of restrictions when using this option.
    +
      +
    1. It is not allowed to specify `dbtable` and `query` options at the same time.
+
2. It is not allowed to specify `query` and `partitionColumn` options at the same time. When the + `partitionColumn` option is required, the subquery can be specified using the `dbtable` option instead and + the partition columns can be qualified using the subquery alias provided as part of `dbtable`.
      + Example:
      + + spark.read.format("jdbc")
      +    .option("dbtable", "(select c1, c2 from t1) as subq")
+    .option("partitionColumn", "subq.c1")
      +    .load() +
+
    @@ -1321,8 +1345,8 @@ the following case-insensitive options: These options must all be specified if any of them is specified. In addition, numPartitions must be specified. They describe how to partition the table when reading in parallel from multiple workers. - partitionColumn must be a numeric column from the table in question. Notice - that lowerBound and upperBound are just used to decide the + partitionColumn must be a numeric, date, or timestamp column from the table in question. + Notice that lowerBound and upperBound are just used to decide the partition stride, not for filtering the rows in table. So all rows in the table will be partitioned and returned. This option applies only to reading. @@ -1338,6 +1362,17 @@ the following case-insensitive options: + + queryTimeout + + The number of seconds the driver will wait for a Statement object to execute to the given + number of seconds. Zero means there is no limit. In the write path, this option depends on + how JDBC drivers implement the API setQueryTimeout, e.g., the h2 JDBC driver + checks the timeout of each query instead of an entire JDBC batch. + It defaults to 0. + + + fetchsize @@ -1372,6 +1407,13 @@ the following case-insensitive options: This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g., indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. This option applies only to writing. + + + cascadeTruncate + + This is a JDBC writer related option. If enabled and supported by the JDBC database (PostgreSQL and Oracle at the moment), this options allows execution of a TRUNCATE TABLE t CASCADE (in the case of PostgreSQL a TRUNCATE TABLE ONLY t CASCADE is executed to prevent inadvertently truncating descendant tables). This will affect other tables, and thus should be used with care. This option applies only to writing. It defaults to the default cascading truncate behaviour of the JDBC database in question, specified in the isCascadeTruncate in each JDBCDialect. + + createTableOptions @@ -1393,6 +1435,13 @@ the following case-insensitive options: The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING". You can also specify partial fields, and the others use the default type mapping. For example, "id DECIMAL(38, 0)". The column names should be identical to the corresponding column names of JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading. + + + pushDownPredicate + + The option to enable or disable predicate push-down into the JDBC data source. The default value is true, in which case Spark will push down filters to the JDBC data source as much as possible. Otherwise, if set to false, no filter will be pushed down to the JDBC data source and thus all filters will be handled by Spark. Predicate push-down is usually turned off when the predicate filtering is performed faster by Spark than by the JDBC data source. + +
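The partitioning options above are easiest to understand together. Below is a hedged PySpark sketch of a parallel JDBC read; the URL, table name, column, bounds and credentials are illustrative placeholders, and an active `SparkSession` named `spark` is assumed.

```python
# Illustrative only: read "schema.people" in 10 partitions striped over id values in [1, 100000].
jdbc_df = (spark.read
    .format("jdbc")
    .option("url", "jdbc:postgresql://dbserver:5432/mydb")
    .option("dbtable", "schema.people")
    .option("user", "username")
    .option("password", "password")
    # partitionColumn/lowerBound/upperBound only define the partition stride;
    # no rows are filtered out by these bounds.
    .option("partitionColumn", "id")
    .option("lowerBound", "1")
    .option("upperBound", "100000")
    .option("numPartitions", "10")
    .load())
```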
    @@ -1433,6 +1482,9 @@ SELECT * FROM resultTable
    +## Avro Files +See the [Apache Avro Data Source Guide](avro-data-source-guide.html). + ## Troubleshooting * The JDBC driver class must be visible to the primordial class loader on the client session and on all executors. This is because Java's DriverManager class does a security check that results in it ignoring all drivers not visible to the primordial class loader when one goes to open a connection. One convenient way to do this is to modify compute_classpath.sh on all worker nodes to include your driver JARs. @@ -1703,7 +1755,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) @@ -1741,8 +1793,13 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`. +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. + Note that all data for a group will be loaded into memory before the function is applied. This can -lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for +lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user to ensure that the grouped data will fit into the available memory. @@ -1757,6 +1814,25 @@ The following example shows how to use `groupby().apply()` to subtract the mean For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and [`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). +### Grouped Aggregate + +Grouped aggregate Pandas UDFs are similar to Spark aggregate functions. Grouped aggregate Pandas UDFs are used with `groupBy().agg()` and +[`pyspark.sql.Window`](api/python/pyspark.sql.html#pyspark.sql.Window). It defines an aggregation from one or more `pandas.Series` +to a scalar value, where each `pandas.Series` represents a column within the group or window. + +Note that this type of UDF does not support partial aggregation and all data for a group or window will be loaded into memory. Also, +only unbounded window is supported with Grouped aggregate Pandas UDFs currently. + +The following example shows how to use this type of UDF to compute mean with groupBy and window operations: + +
    +
    +{% include_example grouped_agg_pandas_udf python/sql/arrow.py %} +
    +
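As the included example file is not reproduced inline here, the following minimal sketch shows the pattern, assuming a DataFrame `df` with a grouping column `id` and a numeric column `v`.

```python
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    # Receives every value of the column for one group (or window) as a pandas.Series.
    return v.mean()

# Grouped aggregation.
df.groupby("id").agg(mean_udf(df["v"]).alias("mean_v")).show()

# Aggregation over an unbounded window.
w = (Window.partitionBy("id")
     .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
df.withColumn("mean_v", mean_udf(df["v"]).over(w)).show()
```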
+ +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf).
In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. 
Users who do not care about this problem and want to keep their queries unchanged can retain the previous behavior by setting `spark.sql.function.rejectTimezoneInString` to `false`. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. + - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This also happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` when `spark.sql.hive.convertMetastoreOrc=true`. Since Spark 2.4, Spark respects Parquet/ORC-specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. + - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. This means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. Setting `spark.sql.hive.convertMetastoreOrc` to `false` restores the previous behavior. + - In version 2.3 and earlier, a CSV row is considered malformed if at least one column value in the row is malformed. The CSV parser drops such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, a CSV row is considered malformed only when it contains malformed column values requested from the CSV datasource; other values can be ignored. As an example, a CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234, but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. + - Since Spark 2.4, file listing for computing statistics is done in parallel by default. This can be disabled by setting `spark.sql.parallelFileListingInStatsComputation.enabled` to `false`. + - Since Spark 2.4, metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during statistics computation. + +## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above + + - As of version 2.3.1, Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
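The corrupt-record note just above is easiest to see in code. The following is a minimal Scala sketch of the cache-then-query workaround it describes; the object name, file path, and schema here are illustrative assumptions, not part of the migration guide.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object CorruptRecordWorkaroundSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("CorruptRecordWorkaroundSketch").getOrCreate()
    import spark.implicits._

    // Assumed schema and input path, for illustration only.
    val schema = new StructType()
      .add("name", StringType)
      .add("age", IntegerType)
      .add("_corrupt_record", StringType)
    val file = "path/to/people.json"

    // Since Spark 2.3, querying only `_corrupt_record` directly on the raw reader is disallowed,
    // so cache (or save) the parsed result first and run the query on that.
    val df = spark.read.schema(schema).json(file).cache()
    df.filter($"_corrupt_record".isNotNull).count()
    df.select("_corrupt_record").show()

    spark.stop()
  }
}
```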
@@ -1967,6 +2058,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. + - The semantics of un-aliased subqueries have not been well defined and can lead to confusing behaviors. Since Spark 2.3, we invalidate such confusing cases, for example `SELECT v.i from (SELECT i FROM v)`: Spark will throw an analysis exception in this case because users should not be able to use the qualifier inside a subquery. See [SPARK-20690](https://issues.apache.org/jira/browse/SPARK-20690) and [SPARK-21335](https://issues.apache.org/jira/browse/SPARK-21335) for more details. ## Upgrading From Spark SQL 2.1 to 2.2 @@ -2073,7 +2165,7 @@ See the API docs for `SQLContext.read` ( Python ) and `DataFrame.write` ( Scala, - Java, + Java, Python ) more information. @@ -2233,7 +2325,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses @@ -2986,3 +3078,10 @@ Specifically: - In aggregations, all NaN values are grouped together. - NaN is treated as a normal value in join keys. - NaN values go last when in ascending order, larger than any other numeric value. + + ## Arithmetic operations + +Operations performed on numeric types (with the exception of `decimal`) are not checked for overflow. +This means that if an operation causes an overflow, the result is the same as what the same operation +returns in a Java/Scala program (e.g., if the sum of two integers is higher than the maximum representable value, +the result is a negative number; a short illustrative sketch appears below). diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 678b0643fd706..6a52e8a7b0ebd 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -196,7 +196,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Running the Example To run the example, -- Download a Spark binary from the [download site](http://spark.apache.org/downloads.html). +- Download a Spark binary from the [download site](https://spark.apache.org/downloads.html). - Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created.
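As a companion to the "Arithmetic operations" note in the SQL guide changes above, here is a small Scala sketch of the unchecked-overflow behavior it describes. The query, alias, and object name are illustrative assumptions.

```scala
import org.apache.spark.sql.SparkSession

object OverflowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("OverflowSketch").getOrCreate()

    // Non-decimal arithmetic is not checked for overflow, so Int.MaxValue + 1 wraps
    // around to a negative number, just as it would in a plain Java/Scala program.
    spark.sql("SELECT 2147483647 + CAST(1 AS INT) AS wrapped").show()
    // the `wrapped` column contains -2147483648

    println(Int.MaxValue + 1)  // -2147483648 in plain Scala as well

    spark.stop()
  }
}
```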
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c30959263cdfa..0ca0f2a8b54d5 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -915,8 +915,7 @@ JavaPairDStream runningCounts = pairs.updateStateByKey(updateFu The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Java code, take a look at the example -[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming -/JavaStatefulNetworkWordCount.java). +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java).
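The `updateStateByKey` description above maps directly onto a small amount of code. Below is an illustrative Scala sketch of such an update function driving a running word count; it is a stand-in for the guide's JavaStatefulNetworkWordCount example, with the host, port, and checkpoint directory chosen arbitrarily.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StatefulWordCountSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("StatefulWordCountSketch")
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("checkpoint")  // updateStateByKey requires a checkpoint directory

    // newValues holds the batch of 1's for a word; runningCount is the previous total, if any.
    val updateFunction: (Seq[Int], Option[Int]) => Option[Int] =
      (newValues, runningCount) => Some(newValues.sum + runningCount.getOrElse(0))

    val lines = ssc.socketTextStream("localhost", 9999)
    val pairs = lines.flatMap(_.split(" ")).map(word => (word, 1))
    val runningCounts = pairs.updateStateByKey[Int](updateFunction)
    runningCounts.print()

    ssc.start()
    ssc.awaitTermination()
  }
}
```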
    @@ -2176,6 +2175,8 @@ the input data stream (using `inputStream.repartition()`). This distributes the received batches of data across the specified number of machines in the cluster before further processing. +For direct stream, please refer to [Spark Streaming + Kafka Integration Guide](streaming-kafka-integration.html) + ### Level of Parallelism in Data Processing {:.no_toc} Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the @@ -2468,7 +2469,7 @@ additional effort may be necessary to achieve exactly-once semantics. There are - [Kafka Integration Guide](streaming-kafka-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) -* Third-party DStream data sources can be found in [Third Party Projects](http://spark.apache.org/third-party-projects.html) +* Third-party DStream data sources can be found in [Third Party Projects](https://spark.apache.org/third-party-projects.html) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 602a4c70848e7..73de1892977ac 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -17,7 +17,7 @@ In this guide, we are going to walk you through the programming model and the AP # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in [Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py)/[R]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming/structured_network_wordcount.R). -And if you [download Spark](http://spark.apache.org/downloads.html), you can directly [run the example](index.html#running-the-examples-and-shell). In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark. +And if you [download Spark](https://spark.apache.org/downloads.html), you can directly [run the example](index.html#running-the-examples-and-shell). In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
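The quick-example paragraph above ends at the point where the necessary classes are imported and a local SparkSession is created. A minimal Scala sketch of that first step, assuming local execution in a shell rather than spark-submit, might look like this:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession
  .builder
  .appName("StructuredNetworkWordCount")
  .master("local[2]")  // assumed here so the sketch runs without spark-submit options
  .getOrCreate()

import spark.implicits._
```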
    @@ -522,7 +522,7 @@ Here are the details of all the sources in Spark.
    maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max)
    - latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files (default: false) + latestFirst: whether to process the latest new files first, useful when there is a large backlog of files (default: false)
    fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same:
    @@ -926,7 +926,7 @@ event time. For a specific window starting at time `T`, the engine will maintain data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, but data later than the threshold will start getting dropped -(see [later]((#semantic-guarantees-of-aggregation-with-watermarking)) +(see [later](#semantic-guarantees-of-aggregation-with-watermarking) in the section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below. @@ -1005,7 +1005,7 @@ Here is an illustration. As shown in the illustration, the maximum event time tracked by the engine is the *blue dashed line*, and the watermark set as `(max event time - '10 mins')` -at the beginning of every trigger is the red line For example, when the engine observes the data +at the beginning of every trigger is the red line. For example, when the engine observes the data `(12:14, dog)`, it sets the watermark for the next trigger as `12:04`. This watermark lets the engine maintain intermediate state for additional 10 minutes to allow late data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in @@ -1162,7 +1162,7 @@ In other words, you will have to do the following additional steps in the join. old rows of one input is not going to be required (i.e. will not satisfy the time constraint) for matches with the other input. This constraint can be defined in one of the two ways. - 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEN rightTime AND rightTime + INTERVAL 1 HOUR`), + 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEEN rightTime AND rightTime + INTERVAL 1 HOUR`), 1. Join on event-time windows (e.g. `...JOIN ON leftTimeWindow = rightTimeWindow`). diff --git a/docs/tuning.md b/docs/tuning.md index 912c39879be8f..1c3bd0e8758ff 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -132,7 +132,7 @@ The best way to size the amount of memory consumption a dataset will require is into cache, and look at the "Storage" page in the web UI. The page will tell you how much memory the RDD is occupying. -To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method +To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method. This is useful for experimenting with different data layouts to trim memory usage, as well as determining the amount of space a broadcast variable will occupy on each executor heap. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java new file mode 100644 index 0000000000000..51865637df6f6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.clustering.PowerIterationClustering; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPowerIterationClusteringExample { + public static void main(String[] args) { + // Create a SparkSession. + SparkSession spark = SparkSession + .builder() + .appName("JavaPowerIterationClustering") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(0L, 1L, 1.0), + RowFactory.create(0L, 2L, 1.0), + RowFactory.create(1L, 2L, 1.0), + RowFactory.create(3L, 4L, 1.0), + RowFactory.create(4L, 0L, 0.1) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("src", DataTypes.LongType, false, Metadata.empty()), + new StructField("dst", DataTypes.LongType, false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + PowerIterationClustering model = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .setWeightCol("weight"); + + Dataset result = model.assignClusters(df); + result.show(false); + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java new file mode 100644 index 0000000000000..e9b84365d86ed --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.*; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.stat.Summarizer; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaSummarizerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaSummarizerExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Vectors.dense(2.0, 3.0, 5.0), 1.0), + RowFactory.create(Vectors.dense(4.0, 6.0, 7.0), 2.0) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + Row result1 = df.select(Summarizer.metrics("mean", "variance") + .summary(new Column("features"), new Column("weight")).as("summary")) + .select("summary.mean", "summary.variance").first(); + System.out.println("with weight: mean = " + result1.getAs(0).toString() + + ", variance = " + result1.getAs(1).toString()); + + Row result2 = df.select( + Summarizer.mean(new Column("features")), + Summarizer.variance(new Column("features")) + ).first(); + System.out.println("without weight: mean = " + result2.getAs(0).toString() + + ", variance = " + result2.getAs(1).toString()); + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java index b48b95ff1d2a3..273273652c955 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java @@ -67,7 +67,7 @@ public static void main(String[] args) { ) ); - // The contingency table is constructed from the raw (feature, label) pairs and used to conduct + // The contingency table is constructed from the raw (label, feature) pairs and used to conduct // the independence test. Returns an array containing the ChiSquaredTestResult for every feature // against the label. ChiSqTestResult[] featureTestResults = Statistics.chiSqTest(obs.rdd()); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index b6b163fa8b2cd..748bf58f30350 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -26,7 +26,9 @@ import scala.Tuple2; +import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.*; @@ -37,30 +39,33 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount. 
- * Usage: JavaDirectKafkaWordCount <brokers> <topics> + * Usage: JavaDirectKafkaWordCount <brokers> <groupId> <topics> * <brokers> is a list of one or more Kafka brokers + * <groupId> is a consumer group name to consume from topics * <topics> is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ public final class JavaDirectKafkaWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaDirectKafkaWordCount <brokers> <topics>\n" + - " <brokers> is a list of one or more Kafka brokers\n" + - " <topics> is a list of one or more kafka topics to consume from\n\n"); + if (args.length < 3) { + System.err.println("Usage: JavaDirectKafkaWordCount <brokers> <groupId> <topics>\n" + + " <brokers> is a list of one or more Kafka brokers\n" + + " <groupId> is a consumer group name to consume from topics\n" + + " <topics> is a list of one or more kafka topics to consume from\n\n"); System.exit(1); } StreamingExamples.setStreamingLogLevels(); String brokers = args[0]; - String topics = args[1]; + String groupId = args[1]; + String topics = args[2]; // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); @@ -68,7 +73,10 @@ public static void main(String[] args) throws Exception { Set<String> topicsSet = new HashSet<>(Arrays.asList(topics.split(","))); Map<String, Object> kafkaParams = new HashMap<>(); - kafkaParams.put("metadata.broker.list", brokers); + kafkaParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers); + kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + kafkaParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + kafkaParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); // Create direct kafka stream with brokers and topics JavaInputDStream<ConsumerRecord<String, String>> messages = KafkaUtils.createDirectStream(
+Run with: + bin/spark-submit examples/src/main/python/ml/summarizer_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.stat import Summarizer +from pyspark.sql import Row +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("SummarizerExample") \ + .getOrCreate() + sc = spark.sparkContext + + # $example on$ + df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + + # create summarizer for multiple metrics "mean" and "count" + summarizer = Summarizer.metrics("mean", "count") + + # compute statistics for multiple metrics with weight + df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + + # compute statistics for multiple metrics without weight + df.select(summarizer.summary(df.features)).show(truncate=False) + + # compute statistics for single metric "mean" with weight + df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + + # compute statistics for single metric "mean" without weight + df.select(Summarizer.mean(df.features)).show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/mllib/hypothesis_testing_example.py b/examples/src/main/python/mllib/hypothesis_testing_example.py index e566ead0d318d..21a5584fd6e06 100644 --- a/examples/src/main/python/mllib/hypothesis_testing_example.py +++ b/examples/src/main/python/mllib/hypothesis_testing_example.py @@ -51,7 +51,7 @@ [LabeledPoint(1.0, [1.0, 0.0, 3.0]), LabeledPoint(1.0, [1.0, 2.0, 0.0]), LabeledPoint(1.0, [-1.0, 0.0, -0.5])] - ) # LabeledPoint(feature, label) + ) # LabeledPoint(label, feature) # The contingency table is constructed from an RDD of LabeledPoint and used to conduct # the independence test. Returns an array containing the ChiSquaredTestResult for every feature diff --git a/examples/src/main/python/py_container_checks.py b/examples/src/main/python/py_container_checks.py new file mode 100644 index 0000000000000..f6b3be2806c82 --- /dev/null +++ b/examples/src/main/python/py_container_checks.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys + + +def version_check(python_env, major_python_version): + """ + These are various tests to test the Python container image. + This file will be distributed via --py-files in the e2e tests. 
+ """ + env_version = os.environ.get('PYSPARK_PYTHON') + print("Python runtime version check is: " + + str(sys.version_info[0] == major_python_version)) + + print("Python environment version check is: " + + str(env_version == python_env)) diff --git a/examples/src/main/python/pyfiles.py b/examples/src/main/python/pyfiles.py new file mode 100644 index 0000000000000..4193654b49a12 --- /dev/null +++ b/examples/src/main/python/pyfiles.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: pyfiles [major_python_version] + """ + spark = SparkSession \ + .builder \ + .appName("PyFilesTest") \ + .getOrCreate() + + from py_container_checks import version_check + # Begin of Python container checks + version_check(sys.argv[1], 2 if sys.argv[1] == "python" else 3) + + spark.stop() diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 4c5aefb6ff4a6..5eb164b20ad04 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -95,12 +95,12 @@ def grouped_map_pandas_udf_example(spark): ("id", "v")) @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) - def substract_mean(pdf): + def subtract_mean(pdf): # pdf is a pandas.DataFrame v = pdf.v return pdf.assign(v=v - v.mean()) - df.groupby("id").apply(substract_mean).show() + df.groupby("id").apply(subtract_mean).show() # +---+----+ # | id| v| # +---+----+ @@ -113,6 +113,43 @@ def substract_mean(pdf): # $example off:grouped_map_pandas_udf$ +def grouped_agg_pandas_udf_example(spark): + # $example on:grouped_agg_pandas_udf$ + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql import Window + + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) + + @pandas_udf("double", PandasUDFType.GROUPED_AGG) + def mean_udf(v): + return v.mean() + + df.groupby("id").agg(mean_udf(df['v'])).show() + # +---+-----------+ + # | id|mean_udf(v)| + # +---+-----------+ + # | 1| 1.5| + # | 2| 6.0| + # +---+-----------+ + + w = Window \ + .partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() + # +---+----+------+ + # | id| v|mean_v| + # +---+----+------+ + # | 1| 1.0| 1.5| + # | 1| 2.0| 1.5| + # | 2| 3.0| 6.0| + # | 2| 5.0| 6.0| + # | 2|10.0| 6.0| + # +---+----+------+ + # $example off:grouped_agg_pandas_udf$ + + if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala 
b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala new file mode 100644 index 0000000000000..ca8f7affb14e8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.clustering.PowerIterationClustering +// $example off$ +import org.apache.spark.sql.SparkSession + +object PowerIterationClusteringExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + + // $example on$ + val dataset = spark.createDataFrame(Seq( + (0L, 1L, 1.0), + (0L, 2L, 1.0), + (1L, 2L, 1.0), + (3L, 4L, 1.0), + (4L, 0L, 0.1) + )).toDF("src", "dst", "weight") + + val model = new PowerIterationClustering(). + setK(2). + setMaxIter(20). + setInitMode("degree"). + setWeightCol("weight") + + val prediction = model.assignClusters(dataset).select("id", "cluster") + + // Shows the cluster assignment + prediction.show(false) + // $example off$ + + spark.stop() + } + } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala new file mode 100644 index 0000000000000..2f54d1d81bc48 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.stat.Summarizer +// $example off$ +import org.apache.spark.sql.SparkSession + +object SummarizerExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("SummarizerExample") + .getOrCreate() + + import spark.implicits._ + import Summarizer._ + + // $example on$ + val data = Seq( + (Vectors.dense(2.0, 3.0, 5.0), 1.0), + (Vectors.dense(4.0, 6.0, 7.0), 2.0) + ) + + val df = data.toDF("features", "weight") + + val (meanVal, varianceVal) = df.select(metrics("mean", "variance") + .summary($"features", $"weight").as("summary")) + .select("summary.mean", "summary.variance") + .as[(Vector, Vector)].first() + + println(s"with weight: mean = ${meanVal}, variance = ${varianceVal}") + + val (meanVal2, varianceVal2) = df.select(mean($"features"), variance($"features")) + .as[(Vector, Vector)].first() + + println(s"without weight: mean = ${meanVal2}, sum = ${varianceVal2}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index add1719739539..9b3c3266ee30a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -61,9 +61,9 @@ object HypothesisTestingExample { LabeledPoint(-1.0, Vectors.dense(-1.0, 0.0, -0.5) ) ) - ) // (feature, label) pairs. + ) // (label, feature) pairs. - // The contingency table is constructed from the raw (feature, label) pairs and used to conduct + // The contingency table is constructed from the raw (label, feature) pairs and used to conduct // the independence test. Returns an array containing the ChiSquaredTestResult for every feature // against the label. 
val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) diff --git a/external/avro/pom.xml b/external/avro/pom.xml new file mode 100644 index 0000000000000..8f118ba48201b --- /dev/null +++ b/external/avro/pom.xml @@ -0,0 +1,78 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../pom.xml + + + spark-avro_2.11 + + avro + + jar + Spark Avro + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..95835f0d4ca49 --- /dev/null +++ b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.avro.AvroFileFormat diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala new file mode 100644 index 0000000000000..915769fa708b0 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema +import org.apache.avro.generic.GenericDatumReader +import org.apache.avro.io.{BinaryDecoder, DecoderFactory} + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} + +case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType + + override def nullable: Boolean = true + + @transient private lazy val avroSchema = new Schema.Parser().parse(jsonFormatSchema) + + @transient private lazy val reader = new GenericDatumReader[Any](avroSchema) + + @transient private lazy val deserializer = new AvroDeserializer(avroSchema, dataType) + + @transient private var decoder: BinaryDecoder = _ + + @transient private var result: Any = _ + + override def nullSafeEval(input: Any): Any = { + val binary = input.asInstanceOf[Array[Byte]] + decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder) + result = reader.read(result, decoder) + deserializer.deserialize(result) + } + + override def simpleString: String = { + s"from_avro(${child.sql}, ${dataType.simpleString})" + } + + override def sql: String = { + s"from_avro(${child.sql}, ${dataType.catalogString})" + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)") + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala new file mode 100644 index 0000000000000..272e7d5b388d9 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.avro + +import java.math.{BigDecimal} +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. + */ +class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { + private lazy val decimalConversions = new DecimalConversion() + + private val converter: Any => Any = rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (data: Any) => InternalRow.empty + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val writer = getRecordWriter(rootAvroType, st, Nil) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + writer(fieldUpdater, record) + resultRow + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + tmpRow.get(0, rootCatalystType) + } + } + + def deserialize(data: Any): Any = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter( + avroType: Schema, + catalystType: DataType, + path: List[String]): (CatalystDataUpdater, Int, Any) => Unit = + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (INT, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => avroType.getLogicalType match { + case _: TimestampMillis => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case _: TimestampMicros => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + case null => (updater, ordinal, value) => + // For backward compatibility, if the Avro type is Long and it is not logical type, + // the value is processed as timestamp type with millisecond precision. 
+ updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case other => throw new IncompatibleSchemaException( + s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.") + } + + // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date. + // For backward compatibility, we still keep this conversion. + case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + bytes + case b: Array[Byte] => b + case other => throw new RuntimeException(s"$other is not a valid avro binary.") + } + updater.set(ordinal, bytes) + + case (FIXED, d: DecimalType) => (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, + LogicalTypes.decimal(d.precision, d.scale)) + val decimal = createDecimal(bigDecimal, d.precision, d.scale) + updater.setDecimal(ordinal, decimal) + + case (BYTES, d: DecimalType) => (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, + LogicalTypes.decimal(d.precision, d.scale)) + val decimal = createDecimal(bigDecimal, d.precision, d.scale) + updater.setDecimal(ordinal, decimal) + + case (RECORD, st: StructType) => + val writeRecord = getRecordWriter(avroType, st, path) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val elementWriter = newWriter(avroType.getElementType, elementType, path) + (updater, ordinal, value) => + val array = value.asInstanceOf[GenericData.Array[Any]] + val len = array.size() + val result = createArrayData(elementType, len) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + while (i < len) { + val element = array.get(i) + if (element == null) { + if (!containsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path) + val valueWriter = newWriter(avroType.getValueType, valueType, path) + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, 
AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, path) + } else { + nonNullTypes.map(_.getType) match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => newWriter(schema, field.dataType, path :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(avroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + } + } + } else { + (updater, ordinal, value) => updater.setNullAt(ordinal) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " + + s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + + // TODO: move the following method in Decimal object on creating Decimal from BigDecimal? + private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + Decimal(decimal.unscaledValue().longValue(), precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. 
+ Decimal(decimal, precision, scale) + } + } + + private def getRecordWriter( + avroType: Schema, + sqlType: StructType, + path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = { + val validFieldIndexes = ArrayBuffer.empty[Int] + val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] + + val length = sqlType.length + var i = 0 + while (i < length) { + val sqlField = sqlType.fields(i) + val avroField = avroType.getField(sqlField.name) + if (avroField != null) { + validFieldIndexes += avroField.pos() + + val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name) + val ordinal = i + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + fieldWriters += fieldWriter + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s""" + |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema. + |Source Avro schema: $rootAvroType. + |Target Catalyst type: $rootCatalystType. + """.stripMargin) + } + i += 1 + } + + (fieldUpdater, record) => { + var i = 0 + while (i < validFieldIndexes.length) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + i += 1 + } + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. 
+ */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala new file mode 100755 index 0000000000000..6df23c93e4c54 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.io._
+import java.net.URI
+
+import scala.util.control.NonFatal
+
+import org.apache.avro.Schema
+import org.apache.avro.file.DataFileConstants._
+import org.apache.avro.file.DataFileReader
+import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
+import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
+import org.apache.avro.mapreduce.AvroJob
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.mapreduce.Job
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+private[avro] class AvroFileFormat extends FileFormat
+  with DataSourceRegister with Logging with Serializable {
+
+  override def equals(other: Any): Boolean = other match {
+    case _: AvroFileFormat => true
+    case _ => false
+  }
+
+  // Dummy hashCode() to appease ScalaStyle.
+  override def hashCode(): Int = super.hashCode()
+
+  override def inferSchema(
+      spark: SparkSession,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Option[StructType] = {
+    val conf = spark.sessionState.newHadoopConf()
+    val parsedOptions = new AvroOptions(options, conf)
+
+    // Schema evolution is not supported yet. Here we only pick a single random sample file to
+    // figure out the schema of the whole dataset.
+    val sampleFile =
+      if (parsedOptions.ignoreExtension) {
+        files.headOption.getOrElse {
+          throw new FileNotFoundException("No files were found for schema inference.")
+        }
+      } else {
+        files.find(_.getPath.getName.endsWith(".avro")).getOrElse {
+          throw new FileNotFoundException(
+            "No Avro files found. If files don't have .avro extension, set ignoreExtension to true")
+        }
+      }
+
+    // The user can specify an optional Avro schema in JSON format.
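+    // Illustrative usage (hypothetical `jsonSchemaString` and `path` values):
+    //   spark.read.format("avro").option("avroSchema", jsonSchemaString).load(path)
+    // When the option is supplied, the sample-file based inference below is skipped.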
+ val avroSchema = parsedOptions.schema + .map(new Schema.Parser().parse) + .getOrElse { + val in = new FsInput(sampleFile.getPath, conf) + try { + val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) + try { + reader.getSchema + } finally { + reader.close() + } + } finally { + in.close() + } + } + + SchemaConverters.toSqlType(avroSchema).dataType match { + case t: StructType => Some(t) + case _ => throw new RuntimeException( + s"""Avro schema cannot be converted to a Spark SQL StructType: + | + |${avroSchema.toString(true)} + |""".stripMargin) + } + } + + override def shortName(): String = "avro" + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = true + + override def prepareWrite( + spark: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf()) + val outputAvroSchema: Schema = parsedOptions.schema + .map(new Schema.Parser().parse) + .getOrElse(SchemaConverters.toAvroType(dataSchema, nullable = false, + parsedOptions.recordName, parsedOptions.recordNamespace)) + + AvroJob.setOutputKeySchema(job, outputAvroSchema) + + if (parsedOptions.compression == "uncompressed") { + job.getConfiguration.setBoolean("mapred.output.compress", false) + } else { + job.getConfiguration.setBoolean("mapred.output.compress", true) + logInfo(s"Compressing Avro output using the ${parsedOptions.compression} codec") + val codec = parsedOptions.compression match { + case DEFLATE_CODEC => + val deflateLevel = spark.sessionState.conf.avroDeflateLevel + logInfo(s"Avro compression level $deflateLevel will be used for $DEFLATE_CODEC codec.") + job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) + DEFLATE_CODEC + case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC) => codec + case unknown => throw new IllegalArgumentException(s"Invalid compression codec: $unknown") + } + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec) + } + + new AvroOutputWriterFactory(dataSchema, outputAvroSchema.toString) + } + + override def buildReader( + spark: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + + val broadcastedConf = + spark.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val parsedOptions = new AvroOptions(options, hadoopConf) + + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + val userProvidedSchema = parsedOptions.schema.map(new Schema.Parser().parse) + + // TODO Removes this check once `FileFormat` gets a general file filtering interface method. + // Doing input file filtering is improper because we may generate empty tasks that process no + // input files but stress the scheduler. We should probably add a more general input file + // filtering mechanism for `FileFormat` data sources. See SPARK-16317. 
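+      // Illustrative behavior of the check below: with the default ignoreExtension = true every
+      // listed file is parsed as Avro; reads can be restricted to `*.avro` files with, e.g.,
+      //   spark.read.format("avro").option("ignoreExtension", "false").load(dir)
+      // (hypothetical `dir`), in which case non-matching files yield an empty iterator.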
+ if (parsedOptions.ignoreExtension || file.filePath.endsWith(".avro")) { + val reader = { + val in = new FsInput(new Path(new URI(file.filePath)), conf) + try { + val datumReader = userProvidedSchema match { + case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema) + case _ => new GenericDatumReader[GenericRecord]() + } + DataFileReader.openReader(in, datumReader) + } catch { + case NonFatal(e) => + logError("Exception while opening DataFileReader", e) + in.close() + throw e + } + } + + // Ensure that the reader is closed even if the task fails or doesn't consume the entire + // iterator of records. + Option(TaskContext.get()).foreach { taskContext => + taskContext.addTaskCompletionListener[Unit] { _ => + reader.close() + } + } + + reader.sync(file.start) + val stop = file.start + file.length + + val deserializer = + new AvroDeserializer(userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) + + new Iterator[InternalRow] { + private[this] var completed = false + + override def hasNext: Boolean = { + if (completed) { + false + } else { + val r = reader.hasNext && !reader.pastSync(stop) + if (!r) { + reader.close() + completed = true + } + r + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("next on empty iterator") + } + val record = reader.next() + deserializer.deserialize(record).asInstanceOf[InternalRow] + } + } + } else { + Iterator.empty + } + } + } +} + +private[avro] object AvroFileFormat { + val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala new file mode 100644 index 0000000000000..67f56343b4524 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.internal.SQLConf + +/** + * Options for Avro Reader and Writer stored in case insensitive manner. + */ +class AvroOptions( + @transient val parameters: CaseInsensitiveMap[String], + @transient val conf: Configuration) extends Logging with Serializable { + + def this(parameters: Map[String, String], conf: Configuration) = { + this(CaseInsensitiveMap(parameters), conf) + } + + /** + * Optional schema provided by an user in JSON format. + */ + val schema: Option[String] = parameters.get("avroSchema") + + /** + * Top level record name in write result, which is required in Avro spec. 
+   * See https://avro.apache.org/docs/1.8.2/spec.html#schema_record .
+   * Default value is "topLevelRecord".
+   */
+  val recordName: String = parameters.getOrElse("recordName", "topLevelRecord")
+
+  /**
+   * Record namespace in write result. Default value is "".
+   * See Avro spec for details: https://avro.apache.org/docs/1.8.2/spec.html#schema_record .
+   */
+  val recordNamespace: String = parameters.getOrElse("recordNamespace", "")
+
+  /**
+   * The `ignoreExtension` option controls whether files without the `.avro` extension are read.
+   * If the option is enabled, all files (with and without the `.avro` extension) are loaded.
+   * If the option is not set, the Hadoop config `avro.mapred.ignore.inputs.without.extension`
+   * is taken into account. If that config is not set either, file extensions are ignored
+   * (all files are loaded).
+   */
+  val ignoreExtension: Boolean = {
+    val ignoreFilesWithoutExtensionByDefault = false
+    val ignoreFilesWithoutExtension = conf.getBoolean(
+      AvroFileFormat.IgnoreFilesWithoutExtensionProperty,
+      ignoreFilesWithoutExtensionByDefault)
+
+    parameters
+      .get("ignoreExtension")
+      .map(_.toBoolean)
+      .getOrElse(!ignoreFilesWithoutExtension)
+  }
+
+  /**
+   * The `compression` option specifies the compression codec to use on write.
+   * Currently supported codecs are `uncompressed`, `snappy`, `deflate`, `bzip2` and `xz`.
+   * If the option is not set, the `spark.sql.avro.compression.codec` config is taken into
+   * account. If that config is not set either, the `snappy` codec is used by default.
+   */
+  val compression: String = {
+    parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec)
+  }
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
new file mode 100644
index 0000000000000..06507115f5ed8
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.io.{IOException, OutputStream}
+
+import org.apache.avro.Schema
+import org.apache.avro.generic.GenericRecord
+import org.apache.avro.mapred.AvroKey
+import org.apache.avro.mapreduce.AvroKeyOutputFormat
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.OutputWriter
+import org.apache.spark.sql.types._
+
+// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
+private[avro] class AvroOutputWriter( + path: String, + context: TaskAttemptContext, + schema: StructType, + avroSchema: Schema) extends OutputWriter { + + // The input rows will never be null. + private lazy val serializer = new AvroSerializer(schema, avroSchema, nullable = false) + + /** + * Overrides the couple of methods responsible for generating the output streams / files so + * that the data can be correctly partitioned + */ + private val recordWriter: RecordWriter[AvroKey[GenericRecord], NullWritable] = + new AvroKeyOutputFormat[GenericRecord]() { + + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + + @throws(classOf[IOException]) + override def getAvroFileOutputStream(c: TaskAttemptContext): OutputStream = { + val path = getDefaultWorkFile(context, ".avro") + path.getFileSystem(context.getConfiguration).create(path) + } + + }.getRecordWriter(context) + + override def write(row: InternalRow): Unit = { + val key = new AvroKey(serializer.serialize(row).asInstanceOf[GenericRecord]) + recordWriter.write(key, NullWritable.get()) + } + + override def close(): Unit = recordWriter.close(context) +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala new file mode 100644 index 0000000000000..116020ed5c433 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.StructType + +/** + * A factory that produces [[AvroOutputWriter]]. + * @param catalystSchema Catalyst schema of input data. + * @param avroSchemaAsJsonString Avro schema of output result, in JSON string format. 
+ */ +private[avro] class AvroOutputWriterFactory( + catalystSchema: StructType, + avroSchemaAsJsonString: String) extends OutputWriterFactory { + + private lazy val avroSchema = new Schema.Parser().parse(avroSchemaAsJsonString) + + override def getFileExtension(context: TaskAttemptContext): String = ".avro" + + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new AvroOutputWriter(path, context, catalystSchema, avroSchema) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala new file mode 100644 index 0000000000000..e902b4c77eaad --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} +import org.apache.avro.Schema +import org.apache.avro.Schema.Type +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} +import org.apache.avro.generic.GenericData.Record +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize data in catalyst format to data in avro format. 
+ */ +class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) { + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val actualAvroType = resolveNullableType(rootAvroType, nullable) + val baseConverter = rootCatalystType match { + case st: StructType => + newStructConverter(st, actualAvroType).asInstanceOf[Any => Any] + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val converter = newConverter(rootCatalystType, actualAvroType) + (data: Any) => + tmpRow.update(0, data) + converter.apply(tmpRow, 0) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private lazy val decimalConversions = new DecimalConversion() + + private def newConverter(catalystType: DataType, avroType: Schema): Converter = { + (catalystType, avroType.getType) match { + case (NullType, NULL) => + (getter, ordinal) => null + case (BooleanType, BOOLEAN) => + (getter, ordinal) => getter.getBoolean(ordinal) + case (ByteType, INT) => + (getter, ordinal) => getter.getByte(ordinal).toInt + case (ShortType, INT) => + (getter, ordinal) => getter.getShort(ordinal).toInt + case (IntegerType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + case (LongType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + case (FloatType, FLOAT) => + (getter, ordinal) => getter.getFloat(ordinal) + case (DoubleType, DOUBLE) => + (getter, ordinal) => getter.getDouble(ordinal) + case (d: DecimalType, FIXED) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (d: DecimalType, BYTES) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (StringType, ENUM) => + val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException( + "Cannot write \"" + data + "\" since it's not defined in enum \"" + + enumSymbols.mkString("\", \"") + "\"") + } + new EnumSymbol(avroType, data) + + case (StringType, STRING) => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + + case (BinaryType, FIXED) => + val size = avroType.getFixedSize() + (getter, ordinal) => + val data: Array[Byte] = getter.getBinary(ordinal) + if (data.length != size) { + throw new IncompatibleSchemaException( + s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " + + "binary data into FIXED Type with size of " + + s"$size ${if (size > 1) "bytes" else "byte"}") + } + new Fixed(avroType, data) + + case (BinaryType, BYTES) => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + + case (DateType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + + case (TimestampType, LONG) => avroType.getLogicalType match { + case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000 + case _: TimestampMicros => (getter, 
ordinal) => getter.getLong(ordinal) + // For backward compatibility, if the Avro type is Long and it is not logical type, + // output the timestamp value as with millisecond precision. + case null => (getter, ordinal) => getter.getLong(ordinal) / 1000 + case other => throw new IncompatibleSchemaException( + s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}") + } + + case (ArrayType(et, containsNull), ARRAY) => + val elementConverter = newConverter( + et, resolveNullableType(avroType.getElementType, containsNull)) + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val len = arrayData.numElements() + val result = new Array[Any](len) + var i = 0 + while (i < len) { + if (containsNull && arrayData.isNullAt(i)) { + result(i) = null + } else { + result(i) = elementConverter(arrayData, i) + } + i += 1 + } + // avro writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. + java.util.Arrays.asList(result: _*) + } + + case (st: StructType, RECORD) => + val structConverter = newStructConverter(st, avroType) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType => + val valueConverter = newConverter( + vt, resolveNullableType(avroType.getValueType, valueContainsNull)) + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val len = mapData.numElements() + val result = new java.util.HashMap[String, Any](len) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < len) { + val key = keyArray.getUTF8String(i).toString + if (valueContainsNull && valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case other => + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " + + s"Avro type $avroType.") + } + } + + private def newStructConverter( + catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { + if (avroStruct.getType != RECORD || avroStruct.getFields.size() != catalystStruct.length) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " + + s"Avro type $avroStruct.") + } + val fieldConverters = catalystStruct.zip(avroStruct.getFields.asScala).map { + case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable)) + } + val numFields = catalystStruct.length + (row: InternalRow) => + val result = new Record(avroStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + result.put(i, null) + } else { + result.put(i, fieldConverters(i).apply(row, i)) + } + i += 1 + } + result + } + + private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { + if (nullable && avroType.getType != NULL) { + // avro uses union to represent nullable type. 
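+      // For example, a nullable Catalyst StringType is written as the union ["string", "null"];
+      // the non-NULL branch is extracted below as the actual value type.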
+ val fields = avroType.getTypes.asScala + assert(fields.length == 2) + val actualType = fields.filter(_.getType != Type.NULL) + assert(actualType.length == 1) + actualType.head + } else { + avroType + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala new file mode 100644 index 0000000000000..141ff3782adfb --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.ByteArrayOutputStream + +import org.apache.avro.generic.GenericDatumWriter +import org.apache.avro.io.{BinaryEncoder, EncoderFactory} + +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{BinaryType, DataType} + +case class CatalystDataToAvro(child: Expression) extends UnaryExpression { + + override def dataType: DataType = BinaryType + + @transient private lazy val avroType = + SchemaConverters.toAvroType(child.dataType, child.nullable) + + @transient private lazy val serializer = + new AvroSerializer(child.dataType, avroType, child.nullable) + + @transient private lazy val writer = + new GenericDatumWriter[Any](avroType) + + @transient private var encoder: BinaryEncoder = _ + + @transient private lazy val out = new ByteArrayOutputStream + + override def nullSafeEval(input: Any): Any = { + out.reset() + encoder = EncoderFactory.get().directBinaryEncoder(out, encoder) + val avroData = serializer.serialize(input) + writer.write(avroData, encoder) + encoder.flush() + out.toByteArray + } + + override def simpleString: String = { + s"to_avro(${child.sql}, ${child.dataType.simpleString})" + } + + override def sql: String = { + s"to_avro(${child.sql}, ${child.dataType.catalogString})" + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(byte[]) $expr.nullSafeEval($input)") + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala new file mode 100644 index 0000000000000..bd1576587d7fa --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema.Type._
+
+import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.Decimal.{maxPrecisionForBytes, minBytesForPrecision}
+
+/**
+ * This object contains methods that are used to convert Spark SQL schemas to Avro schemas and
+ * vice versa.
+ */
+object SchemaConverters {
+  private lazy val uuidGenerator = RandomUUIDGenerator(new Random().nextLong())
+
+  private lazy val nullSchema = Schema.create(Schema.Type.NULL)
+
+  case class SchemaType(dataType: DataType, nullable: Boolean)
+
+  /**
+   * This function takes an Avro schema and returns a corresponding Spark SQL schema.
+   */
+  def toSqlType(avroSchema: Schema): SchemaType = {
+    avroSchema.getType match {
+      case INT => avroSchema.getLogicalType match {
+        case _: Date => SchemaType(DateType, nullable = false)
+        case _ => SchemaType(IntegerType, nullable = false)
+      }
+      case STRING => SchemaType(StringType, nullable = false)
+      case BOOLEAN => SchemaType(BooleanType, nullable = false)
+      case BYTES | FIXED => avroSchema.getLogicalType match {
+        // For FIXED type, if the precision requires more bytes than the fixed size, the logical
+        // type will be null, which is handled by the Avro library.
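+        // For example, {"type": "fixed", "size": 5, "logicalType": "decimal",
+        //  "precision": 4, "scale": 2} is converted to DecimalType(4, 2) by the case below.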
+ case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false) + case _ => SchemaType(BinaryType, nullable = false) + } + + case DOUBLE => SchemaType(DoubleType, nullable = false) + case FLOAT => SchemaType(FloatType, nullable = false) + case LONG => avroSchema.getLogicalType match { + case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false) + case _ => SchemaType(LongType, nullable = false) + } + + case ENUM => SchemaType(StringType, nullable = false) + + case RECORD => + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlType(f.schema()) + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + + SchemaType(StructType(fields), nullable = false) + + case ARRAY => + val schemaType = toSqlType(avroSchema.getElementType) + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + + case MAP => + val schemaType = toSqlType(avroSchema.getValueType) + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + + case UNION => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + toSqlType(remainingUnionTypes.head).copy(nullable = true) + } else { + toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + toSqlType(avroSchema.getTypes.get(0)) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + SchemaType(LongType, nullable = false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + SchemaType(DoubleType, nullable = false) + case _ => + // Convert complex unions to struct types where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. 
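+            // For example, the union ["int", "string", "boolean"] becomes
+            // struct<member0: int, member1: string, member2: boolean>, with every field nullable.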
+ val fields = avroSchema.getTypes.asScala.zipWithIndex.map { + case (s, i) => + val schemaType = toSqlType(s) + // All fields are nullable because only one of them is set at a time + StructField(s"member$i", schemaType.dataType, nullable = true) + } + + SchemaType(StructType(fields), nullable = false) + } + + case other => throw new IncompatibleSchemaException(s"Unsupported type $other") + } + } + + def toAvroType( + catalystType: DataType, + nullable: Boolean = false, + recordName: String = "topLevelRecord", + nameSpace: String = "") + : Schema = { + val builder = SchemaBuilder.builder() + + val schema = catalystType match { + case BooleanType => builder.booleanType() + case ByteType | ShortType | IntegerType => builder.intType() + case LongType => builder.longType() + case DateType => + LogicalTypes.date().addToSchema(builder.intType()) + case TimestampType => + LogicalTypes.timestampMicros().addToSchema(builder.longType()) + + case FloatType => builder.floatType() + case DoubleType => builder.doubleType() + case StringType => builder.stringType() + case d: DecimalType => + val avroType = LogicalTypes.decimal(d.precision, d.scale) + val fixedSize = minBytesForPrecision(d.precision) + // Need to avoid naming conflict for the fixed fields + val name = nameSpace match { + case "" => s"$recordName.fixed" + case _ => s"$nameSpace.$recordName.fixed" + } + avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize)) + + case BinaryType => builder.bytesType() + case ArrayType(et, containsNull) => + builder.array() + .items(toAvroType(et, containsNull, recordName, nameSpace)) + case MapType(StringType, vt, valueContainsNull) => + builder.map() + .values(toAvroType(vt, valueContainsNull, recordName, nameSpace)) + case st: StructType => + val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName + val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() + st.foreach { f => + val fieldAvroType = + toAvroType(f.dataType, f.nullable, f.name, childNameSpace) + fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() + } + fieldsAssembler.endRecord() + + // This should never happen. + case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") + } + if (nullable) { + Schema.createUnion(schema, nullSchema) + } else { + schema + } + } +} + +class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala new file mode 100755 index 0000000000000..97f9427f96c55 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental + +package object avro { + /** + * Converts a binary column of avro format into its corresponding catalyst value. The specified + * schema must match the read data, otherwise the behavior is undefined: it may fail or return + * arbitrary result. + * + * @param data the binary column. + * @param jsonFormatSchema the avro schema in JSON string format. + * + * @since 2.4.0 + */ + @Experimental + def from_avro(data: Column, jsonFormatSchema: String): Column = { + new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema)) + } + + /** + * Converts a column into binary of avro format. + * + * @param data the data column. + * + * @since 2.4.0 + */ + @Experimental + def to_avro(data: Column): Column = { + new Column(CatalystDataToAvro(data.expr)) + } +} diff --git a/external/avro/src/test/resources/episodes.avro b/external/avro/src/test/resources/episodes.avro new file mode 100644 index 0000000000000..58a028ce19e6a Binary files /dev/null and b/external/avro/src/test/resources/episodes.avro differ diff --git a/external/avro/src/test/resources/log4j.properties b/external/avro/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..75e3b53a093f6 --- /dev/null +++ b/external/avro/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN + diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro new file mode 100755 index 0000000000000..fece892444979 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro new file mode 100755 index 0000000000000..1ca623a07dcf3 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro new file mode 100755 index 0000000000000..a12e9459e7461 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro new file mode 100755 index 0000000000000..60c095691d5d5 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro new file mode 100755 index 0000000000000..af56dfc8083dc Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro new file mode 100755 index 0000000000000..87d78447526f9 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro new file mode 100755 index 0000000000000..c326fc434bf18 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro new file mode 100755 index 0000000000000..279f36c317eb8 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro new file mode 100755 index 0000000000000..8d70f5d1274d4 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro differ diff --git 
a/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro new file mode 100755 index 0000000000000..6839d7217e492 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro new file mode 100755 index 0000000000000..aedc7f7e0e61c Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro differ diff --git a/external/avro/src/test/resources/test.avro b/external/avro/src/test/resources/test.avro new file mode 100644 index 0000000000000..6425e2107e304 Binary files /dev/null and b/external/avro/src/test/resources/test.avro differ diff --git a/external/avro/src/test/resources/test.avsc b/external/avro/src/test/resources/test.avsc new file mode 100644 index 0000000000000..d7119a01f6aa0 --- /dev/null +++ b/external/avro/src/test/resources/test.avsc @@ -0,0 +1,53 @@ +{ + "type" : "record", + "name" : "test_schema", + "fields" : [{ + "name" : "string", + "type" : "string", + "doc" : "Meaningless string of characters" + }, { + "name" : "simple_map", + "type" : {"type": "map", "values": "int"} + }, { + "name" : "complex_map", + "type" : {"type": "map", "values": {"type": "map", "values": "string"}} + }, { + "name" : "union_string_null", + "type" : ["null", "string"] + }, { + "name" : "union_int_long_null", + "type" : ["int", "long", "null"] + }, { + "name" : "union_float_double", + "type" : ["float", "double"] + }, { + "name": "fixed3", + "type": {"type": "fixed", "size": 3, "name": "fixed3"} + }, { + "name": "fixed2", + "type": {"type": "fixed", "size": 2, "name": "fixed2"} + }, { + "name": "enum", + "type": { "type": "enum", + "name": "Suit", + "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + } + }, { + "name": "record", + "type": { + "type": "record", + "name": "record", + "aliases": ["RecordAlias"], + "fields" : [{ + "name": "value_field", + "type": "string" + }] + } + }, { + "name": "array_of_boolean", + "type": {"type": "array", "items": "boolean"} + }, { + "name": "bytes", + "type": "bytes" + }] +} diff --git a/external/avro/src/test/resources/test.json b/external/avro/src/test/resources/test.json new file mode 100644 index 0000000000000..780189a92b378 --- /dev/null +++ b/external/avro/src/test/resources/test.json @@ -0,0 +1,42 @@ +{ + "string": "OMG SPARK IS AWESOME", + "simple_map": {"abc": 1, "bcd": 7}, + "complex_map": {"key": {"a": "b", "c": "d"}}, + "union_string_null": {"string": "abc"}, + "union_int_long_null": {"int": 1}, + "union_float_double": {"float": 3.1415926535}, + "fixed3":"\u0002\u0003\u0004", + "fixed2":"\u0011\u0012", + "enum": "SPADES", + "record": {"value_field": "Two things are infinite: the universe and human stupidity; and I'm not sure about universe."}, + "array_of_boolean": [true, false, false], + "bytes": "\u0041\u0042\u0043" +} +{ + "string": "Terran is IMBA!", + "simple_map": {"mmm": 0, "qqq": 66}, + "complex_map": {"key": {"1": "2", "3": "4"}}, + "union_string_null": {"string": "123"}, + "union_int_long_null": {"long": 66}, + "union_float_double": {"double": 6.6666666666666}, + "fixed3":"\u0007\u0007\u0007", + "fixed2":"\u0001\u0002", + "enum": "CLUBS", + "record": {"value_field": "Life did not intend to make us perfect. 
Whoever is perfect belongs in a museum."}, + "array_of_boolean": [], + "bytes": "" +} +{ + "string": "The cake is a LIE!", + "simple_map": {}, + "complex_map": {"key": {}}, + "union_string_null": {"null": null}, + "union_int_long_null": {"null": null}, + "union_float_double": {"double": 0}, + "fixed3":"\u0011\u0022\u0009", + "fixed2":"\u0010\u0090", + "enum": "DIAMONDS", + "record": {"value_field": "TEST_STR123"}, + "array_of_boolean": [false], + "bytes": "\u0053" +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala new file mode 100644 index 0000000000000..8334cca6cd8f1 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.types._ + +class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def roundTripTest(data: Literal): Unit = { + val avroType = SchemaConverters.toAvroType(data.dataType, data.nullable) + checkResult(data, avroType.toString, data.eval()) + } + + private def checkResult(data: Literal, schema: String, expected: Any): Unit = { + checkEvaluation( + AvroDataToCatalyst(CatalystDataToAvro(data), schema), + prepareExpectedResult(expected)) + } + + private def assertFail(data: Literal, schema: String): Unit = { + intercept[java.io.EOFException] { + AvroDataToCatalyst(CatalystDataToAvro(data), schema).eval() + } + } + + private val testingTypes = Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(8, 0), // 32 bits decimal without fraction + DecimalType(8, 4), // 32 bits decimal + DecimalType(16, 0), // 64 bits decimal without fraction + DecimalType(16, 11), // 64 bits decimal + DecimalType(38, 0), + DecimalType(38, 38), + StringType, + BinaryType) + + protected def prepareExpectedResult(expected: Any): Any = expected match { + // Spark byte and short both map to avro int + case b: Byte => b.toInt + case s: Short => s.toInt + case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) + case map: MapData => + val keys 
= new GenericArrayData( + map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + val values = new GenericArrayData( + map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + new ArrayBasedMapData(keys, values) + case other => other + } + + testingTypes.foreach { dt => + val seed = scala.util.Random.nextLong() + test(s"single $dt with seed $seed") { + val rand = new scala.util.Random(seed) + val data = RandomDataGenerator.forType(dt, rand = rand).get.apply() + val converter = CatalystTypeConverters.createToCatalystConverter(dt) + val input = Literal.create(converter(data), dt) + roundTripTest(input) + } + } + + for (_ <- 1 to 5) { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = RandomDataGenerator.randomSchema(rand, 5, testingTypes) + test(s"flat schema ${schema.catalogString} with seed $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + + for (_ <- 1 to 5) { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes) + test(s"nested schema ${schema.catalogString} with seed $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + + test("read int as string") { + val data = Literal(1) + val avroTypeJson = + s""" + |{ + | "type": "string", + | "name": "my_string" + |} + """.stripMargin + + // When read int as string, avro reader is not able to parse the binary and fail. + assertFail(data, avroTypeJson) + } + + test("read string as int") { + val data = Literal("abc") + val avroTypeJson = + s""" + |{ + | "type": "int", + | "name": "my_int" + |} + """.stripMargin + + // When read string data as int, avro reader is not able to find the type mismatch and read + // the string length as int value. + checkResult(data, avroTypeJson, 3) + } + + test("read float as double") { + val data = Literal(1.23f) + val avroTypeJson = + s""" + |{ + | "type": "double", + | "name": "my_double" + |} + """.stripMargin + + // When read float data as double, avro reader fails(trying to read 8 bytes while the data have + // only 4 bytes). + assertFail(data, avroTypeJson) + } + + test("read double as float") { + val data = Literal(1.23) + val avroTypeJson = + s""" + |{ + | "type": "float", + | "name": "my_float" + |} + """.stripMargin + + // avro reader reads the first 4 bytes of a double as a float, the result is totally undefined. + checkResult(data, avroTypeJson, 5.848603E35f) + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala new file mode 100644 index 0000000000000..90a4cd6ccf9dd --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.test.SharedSQLContext + +class AvroFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("roundtrip in to_avro and from_avro - int and string") { + val df = spark.range(10).select('id, 'id.cast("string").as("str")) + + val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b")) + val avroTypeLong = s""" + |{ + | "type": "int", + | "name": "id" + |} + """.stripMargin + val avroTypeStr = s""" + |{ + | "type": "string", + | "name": "str" + |} + """.stripMargin + checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df) + } + + test("roundtrip in to_avro and from_avro - struct") { + val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct")) + val avroStructDF = df.select(to_avro('struct).as("avro")) + val avroTypeStruct = s""" + |{ + | "type": "record", + | "name": "struct", + | "fields": [ + | {"name": "col1", "type": "long"}, + | {"name": "col2", "type": "string"} + | ] + |} + """.stripMargin + checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df) + } + + test("roundtrip in to_avro and from_avro - array with null") { + val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") + val avroTypeArrStruct = s""" + |[ { + | "type" : "array", + | "items" : [ { + | "type" : "record", + | "name" : "x", + | "fields" : [ { + | "name" : "y", + | "type" : "int" + | } ] + | }, "null" ] + |}, "null" ] + """.stripMargin + val readBackOne = dfOne.select(to_avro($"array").as("avro")) + .select(from_avro($"avro", avroTypeArrStruct).as("array")) + checkAnswer(dfOne, readBackOne) + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala new file mode 100644 index 0000000000000..79ba2871c2264 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.avro + +import java.io.File +import java.sql.Timestamp + +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types.{StructField, StructType, TimestampType} + +class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + import testImplicits._ + + val dateSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "date", "type": {"type": "int", "logicalType": "date"}} + ] + } + """ + + val dateInputData = Seq(7, 365, 0) + + def dateFile(path: String): String = { + val schema = new Schema.Parser().parse(dateSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val result = s"$path/test.avro" + dataFileWriter.create(schema, new File(result)) + + dateInputData.foreach { x => + val record = new GenericData.Record(schema) + record.put("date", x) + dataFileWriter.append(record) + } + dataFileWriter.flush() + dataFileWriter.close() + result + } + + test("Logical type: date") { + withTempDir { dir => + val expected = dateInputData.map(t => Row(DateTimeUtils.toJavaDate(t))) + val dateAvro = dateFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(dateAvro) + + checkAnswer(df, expected) + + checkAnswer(spark.read.format("avro").option("avroSchema", dateSchema).load(dateAvro), + expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + val timestampSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, + {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, + {"name": "long", "type": "long"} + ] + } + """ + + val timestampInputData = Seq((1000L, 2000L, 3000L), (666000L, 999000L, 777000L)) + + def timestampFile(path: String): String = { + val schema = new Schema.Parser().parse(timestampSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val result = s"$path/test.avro" + dataFileWriter.create(schema, new File(result)) + + timestampInputData.foreach { t => + val record = new GenericData.Record(schema) + record.put("timestamp_millis", t._1) + // For microsecond precision, we multiple the value by 1000 to match the expected answer as + // timestamp with millisecond precision. 
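+      // For example, an expected answer of 2000 ms is stored below as
+      // 2000 * 1000 = 2,000,000 microseconds.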
+ record.put("timestamp_micros", t._2 * 1000) + record.put("long", t._3) + dataFileWriter.append(record) + } + dataFileWriter.flush() + dataFileWriter.close() + result + } + + test("Logical type: timestamp_millis") { + withTempDir { dir => + val expected = timestampInputData.map(t => Row(new Timestamp(t._1))) + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis) + + checkAnswer(df, expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: timestamp_micros") { + withTempDir { dir => + val expected = timestampInputData.map(t => Row(new Timestamp(t._2))) + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros) + + checkAnswer(df, expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: user specified output schema with different timestamp types") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = + spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros) + + val expected = timestampInputData.map(t => Row(new Timestamp(t._1), new Timestamp(t._2))) + + val userSpecifiedTimestampSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "timestamp_millis", + "type": [{"type": "long","logicalType": "timestamp-micros"}, "null"]}, + {"name": "timestamp_micros", + "type": [{"type": "long","logicalType": "timestamp-millis"}, "null"]} + ] + } + """ + + withTempPath { path => + df.write + .format("avro") + .option("avroSchema", userSpecifiedTimestampSchema) + .save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Read Long type as Timestamp") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val schema = StructType(StructField("long", TimestampType, true) :: Nil) + val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long) + + val expected = timestampInputData.map(t => Row(new Timestamp(t._3))) + + checkAnswer(df, expected) + } + } + + test("Logical type: user specified read schema") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val expected = timestampInputData + .map(t => Row(new Timestamp(t._1), new Timestamp(t._2), t._3)) + + val df = spark.read.format("avro").option("avroSchema", timestampSchema).load(timestampAvro) + checkAnswer(df, expected) + } + } + + val decimalInputData = Seq("1.23", "4.56", "78.90", "-1", "-2.31") + + def decimalSchemaAndFile(path: String): (String, String) = { + val precision = 4 + val scale = 2 + val bytesFieldName = "bytes" + val bytesSchema = s"""{ + "type":"bytes", + "logicalType":"decimal", + "precision":$precision, + "scale":$scale + } + """ + + val fixedFieldName = "fixed" + val fixedSchema = s"""{ + "type":"fixed", + "size":5, + "logicalType":"decimal", + "precision":$precision, + "scale":$scale, + "name":"foo" + } + """ + val avroSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "$bytesFieldName", "type": $bytesSchema}, + {"name": "$fixedFieldName", "type": $fixedSchema} + ] + } + """ + val schema = new 
Schema.Parser().parse(avroSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val decimalConversion = new DecimalConversion + val avroFile = s"$path/test.avro" + dataFileWriter.create(schema, new File(avroFile)) + val logicalType = LogicalTypes.decimal(precision, scale) + + decimalInputData.map { x => + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal(x).setScale(scale) + val bytes = + decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType) + avroRec.put(bytesFieldName, bytes) + val fixed = + decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType) + avroRec.put(fixedFieldName, fixed) + dataFileWriter.append(avroRec) + } + dataFileWriter.flush() + dataFileWriter.close() + + (avroSchema, avroFile) + } + + test("Logical type: Decimal") { + withTempDir { dir => + val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath) + val expected = + decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + checkAnswer(df, expected) + checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile), + expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: write Decimal with BYTES type") { + val specifiedSchema = """ + { + "type" : "record", + "name" : "topLevelRecord", + "namespace" : "topLevelRecord", + "fields" : [ { + "name" : "bytes", + "type" : [ { + "type" : "bytes", + "namespace" : "topLevelRecord.bytes", + "logicalType" : "decimal", + "precision" : 4, + "scale" : 2 + }, "null" ] + }, { + "name" : "fixed", + "type" : [ { + "type" : "bytes", + "logicalType" : "decimal", + "precision" : 4, + "scale" : 2 + }, "null" ] + } ] + } + """ + withTempDir { dir => + val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath) + assert(specifiedSchema != avroSchema) + val expected = + decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + + withTempPath { path => + df.write.format("avro").option("avroSchema", specifiedSchema).save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: Decimal with too large precision") { + withTempDir { dir => + val schema = new Schema.Parser().parse("""{ + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [{ + "name": "decimal", + "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2} + }] + }""") + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678") + val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38)) + avroRec.put("decimal", bytes) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val msg = intercept[SparkException] { + spark.read.format("avro").load(s"$dir.avro").collect() + }.getCause.getMessage + assert(msg.contains("Unscaled value too large for precision")) + } + } +} diff 
--git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala new file mode 100644 index 0000000000000..9ad4388414eaa --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -0,0 +1,1269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.net.URL +import java.nio.file.{Files, Paths} +import java.sql.{Date, Timestamp} +import java.util.{TimeZone, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.Schema.Type._ +import org.apache.avro.file.{DataFileReader, DataFileWriter} +import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} +import org.apache.commons.io.FileUtils + +import org.apache.spark.SparkException +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types._ + +class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + import testImplicits._ + + val episodesAvro = testFile("episodes.avro") + val testAvro = testFile("test.avro") + + override protected def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + } + + def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { + val originalEntries = spark.read.format("avro").load(testAvro).collect() + val newEntries = spark.read.format("avro").load(newFile) + checkAnswer(newEntries, originalEntries) + } + + def checkAvroSchemaEquals(avroSchema: String, expectedAvroSchema: String): Unit = { + assert(new Schema.Parser().parse(avroSchema) == + new Schema.Parser().parse(expectedAvroSchema)) + } + + def getAvroSchemaStringFromFiles(filePath: String): String = { + new DataFileReader({ + val file = new File(filePath) + if (file.isFile) { + file + } else { + file.listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("avro")) + .head + } + }, new GenericDatumReader[Any]()).getSchema.toString(false) + } + + test("resolve avro data source") { + val databricksAvro = "com.databricks.spark.avro" + // By default the backward compatibility for com.databricks.spark.avro is enabled. 
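+    // The mapping from the old name to the built-in source is governed by
+    // SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED; the withSQLConf block below turns it
+    // off to verify that looking up the old name then fails.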
+ Seq("avro", "org.apache.spark.sql.avro.AvroFileFormat", databricksAvro).foreach { provider => + assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) === + classOf[org.apache.spark.sql.avro.AvroFileFormat]) + } + + withSQLConf(SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED.key -> "false") { + val message = intercept[AnalysisException] { + DataSource.lookupDataSource(databricksAvro, spark.sessionState.conf) + }.getMessage + assert(message.contains(s"Failed to find data source: $databricksAvro")) + } + } + + test("reading from multiple paths") { + val df = spark.read.format("avro").load(episodesAvro, episodesAvro) + assert(df.count == 16) + } + + test("reading and writing partitioned data") { + val df = spark.read.format("avro").load(episodesAvro) + val fields = List("title", "air_date", "doctor") + for (field <- fields) { + withTempPath { dir => + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.partitionBy(field).format("avro").save(outputDir) + val input = spark.read.format("avro").load(outputDir) + // makes sure that no fields got dropped. + // We convert Rows to Seqs in order to work around SPARK-10325 + assert(input.select(field).collect().map(_.toSeq).toSet === + df.select(field).collect().map(_.toSeq).toSet) + } + } + } + + test("request no fields") { + val df = spark.read.format("avro").load(episodesAvro) + df.createOrReplaceTempView("avro_table") + assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) + } + + test("convert formats") { + withTempPath { dir => + val df = spark.read.format("avro").load(episodesAvro) + df.write.parquet(dir.getCanonicalPath) + assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) + } + } + + test("rearrange internal schema") { + withTempPath { dir => + val df = spark.read.format("avro").load(episodesAvro) + df.select("doctor", "title").write.format("avro").save(dir.getCanonicalPath) + } + } + + test("test NULL avro type") { + withTempPath { dir => + val fields = + Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + avroRec.put("null", null) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + intercept[IncompatibleSchemaException] { + spark.read.format("avro").load(s"$dir.avro") + } + } + } + + test("union(int, long) is read as long") { + withTempPath { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toLong) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.format("avro").load(s"$dir.avro") + assert(df.schema.fields === 
Seq(StructField("field1", LongType, nullable = true))) + assert(df.collect().toSet == Set(Row(1L), Row(2L))) + } + } + + test("union(float, double) is read as double") { + withTempPath { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2.toDouble) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.format("avro").load(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble))) + } + } + + test("union(float, double, null) is read as nullable double") { + withTempPath { dir => + val avroSchema: Schema = { + val union = Schema.createUnion( + List(Schema.create(Type.FLOAT), + Schema.create(Type.DOUBLE), + Schema.create(Type.NULL) + ).asJava + ) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", null) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.format("avro").load(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(null))) + } + } + + test("Union of a single type") { + withTempPath { dir => + val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + + avroRec.put("field1", 8) + + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.read.format("avro").load(s"$dir.avro") + assert(df.first() == Row(8)) + } + } + + test("Complex Union Type") { + withTempPath { dir => + val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) + val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) + val complexUnionType = Schema.createUnion( + List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) + val fields = Seq( + new Field("field1", complexUnionType, "doc", null), + new Field("field2", 
complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) + ).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val field1 = 1234 + val field2 = "Hope that was not load bearing" + val field3 = Array[Byte](1, 2, 3, 4) + val field4 = "e2" + avroRec.put("field1", field1) + avroRec.put("field2", field2) + avroRec.put("field3", new Fixed(fixedSchema, field3)) + avroRec.put("field4", new EnumSymbol(enumSchema, field4)) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.sqlContext.read.format("avro").load(s"$dir.avro") + assertResult(field1)(df.selectExpr("field1.member0").first().get(0)) + assertResult(field2)(df.selectExpr("field2.member1").first().get(0)) + assertResult(field3)(df.selectExpr("field3.member2").first().get(0)) + assertResult(field4)(df.selectExpr("field4.member3").first().get(0)) + } + } + + test("Lots of nulls") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("binary", BinaryType, true), + StructField("timestamp", TimestampType, true), + StructField("array", ArrayType(ShortType), true), + StructField("map", MapType(StringType, StringType), true), + StructField("struct", StructType(Seq(StructField("int", IntegerType, true)))))) + val rdd = spark.sparkContext.parallelize(Seq[Row]( + Row(null, new Timestamp(1), Array[Short](1, 2, 3), null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null))) + val df = spark.createDataFrame(rdd, schema) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + } + } + + test("Struct field type") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("short", ShortType, true), + StructField("byte", ByteType, true), + StructField("boolean", BooleanType, true) + )) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, 1.toShort, 1.toByte, true), + Row(2f, 2.toShort, 2.toByte, true), + Row(3f, 3.toShort, 3.toByte, true) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + } + } + + test("Date field type") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("date", DateType, true) + )) + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, null), + Row(2f, new Date(1451948400000L)), + Row(3f, new Date(1460066400500L)) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + checkAnswer( + spark.read.format("avro").load(dir.toString).select("date"), + Seq(Row(null), Row(new Date(1451865600000L)), Row(new Date(1459987200000L)))) + } + } + + test("Array data types") { + withTempPath { dir => + val testSchema = StructType(Seq( + StructField("byte_array", ArrayType(ByteType), true), + StructField("short_array", ArrayType(ShortType), true), + StructField("float_array", 
ArrayType(FloatType), true), + StructField("bool_array", ArrayType(BooleanType), true), + StructField("long_array", ArrayType(LongType), true), + StructField("double_array", ArrayType(DoubleType), true), + StructField("decimal_array", ArrayType(DecimalType(10, 0)), true), + StructField("bin_array", ArrayType(BinaryType), true), + StructField("timestamp_array", ArrayType(TimestampType), true), + StructField("array_array", ArrayType(ArrayType(StringType), true), true), + StructField("struct_array", ArrayType( + StructType(Seq(StructField("name", StringType, true))))))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + + val rdd = spark.sparkContext.parallelize(Seq( + Row(arrayOfByte, Array[Short](1, 2, 3, 4), Array[Float](1f, 2f, 3f, 4f), + Array[Boolean](true, false, true, false), Array[Long](1L, 2L), Array[Double](1.0, 2.0), + Array[BigDecimal](BigDecimal.valueOf(3)), Array[Array[Byte]](arrayOfByte, arrayOfByte), + Array[Timestamp](new Timestamp(0)), + Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")), + Array[Row](Row("Bobby G. can't swim"))))) + val df = spark.createDataFrame(rdd, testSchema) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + } + } + + test("write with compression - sql configs") { + withTempPath { dir => + val uncompressDir = s"$dir/uncompress" + val bzip2Dir = s"$dir/bzip2" + val xzDir = s"$dir/xz" + val deflateDir = s"$dir/deflate" + val snappyDir = s"$dir/snappy" + + val df = spark.read.format("avro").load(testAvro) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "uncompressed") + df.write.format("avro").save(uncompressDir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "bzip2") + df.write.format("avro").save(bzip2Dir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "xz") + df.write.format("avro").save(xzDir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "deflate") + spark.conf.set(SQLConf.AVRO_DEFLATE_LEVEL.key, "9") + df.write.format("avro").save(deflateDir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "snappy") + df.write.format("avro").save(snappyDir) + + val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) + val bzip2Size = FileUtils.sizeOfDirectory(new File(bzip2Dir)) + val xzSize = FileUtils.sizeOfDirectory(new File(xzDir)) + val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) + val snappySize = FileUtils.sizeOfDirectory(new File(snappyDir)) + + assert(uncompressSize > deflateSize) + assert(snappySize > deflateSize) + assert(snappySize > bzip2Size) + assert(bzip2Size > xzSize) + } + } + + test("dsl test") { + val results = spark.read.format("avro").load(episodesAvro).select("title").collect() + assert(results.length === 8) + } + + test("old avro data source name works") { + val results = + spark.read.format("com.databricks.spark.avro") + .load(episodesAvro).select("title").collect() + assert(results.length === 8) + } + + test("support of various data types") { + // This test uses data from test.avro. 
You can see the data and the schema of this file in + // test.json and test.avsc + val all = spark.read.format("avro").load(testAvro).collect() + assert(all.length == 3) + + val str = spark.read.format("avro").load(testAvro).select("string").collect() + assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) + + val simple_map = spark.read.format("avro").load(testAvro).select("simple_map").collect() + assert(simple_map(0)(0).getClass.toString.contains("Map")) + assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) + + val union0 = spark.read.format("avro").load(testAvro).select("union_string_null").collect() + assert(union0.map(_(0)).toSet == Set("abc", "123", null)) + + val union1 = spark.read.format("avro").load(testAvro).select("union_int_long_null").collect() + assert(union1.map(_(0)).toSet == Set(66, 1, null)) + + val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect() + assert( + union2 + .map(x => new java.lang.Double(x(0).toString)) + .exists(p => Math.abs(p - Math.PI) < 0.001)) + + val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect() + assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) + + val enum = spark.read.format("avro").load(testAvro).select("enum").collect() + assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) + + val record = spark.read.format("avro").load(testAvro).select("record").collect() + assert(record(0)(0).getClass.toString.contains("Row")) + assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) + + val array_of_boolean = + spark.read.format("avro").load(testAvro).select("array_of_boolean").collect() + assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) + + val bytes = spark.read.format("avro").load(testAvro).select("bytes").collect() + assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) + } + + test("sql test") { + spark.sql( + s""" + |CREATE TEMPORARY VIEW avroTable + |USING avro + |OPTIONS (path "${episodesAvro}") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM avroTable").collect().length === 8) + } + + test("conversion to avro and back") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + withTempPath { dir => + val avroDir = s"$dir/avro" + spark.read.format("avro").load(testAvro).write.format("avro").save(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) + } + } + + test("conversion to avro and back with namespace") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. 
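+    // The recordName and recordNamespace options set the name and namespace of the top-level
+    // Avro record in the written files, which is what the raw-file check below verifies.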
+ withTempPath { tempDir => + val name = "AvroTest" + val namespace = "org.apache.spark.avro" + val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) + + val avroDir = tempDir + "/namedAvro" + spark.read.format("avro").load(testAvro) + .write.options(parameters).format("avro").save(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) + + // Look at raw file and make sure has namespace info + val rawSaved = spark.sparkContext.textFile(avroDir) + val schema = rawSaved.collect().mkString("") + assert(schema.contains(name)) + assert(schema.contains(namespace)) + } + } + + test("converting some specific sparkSQL types to avro") { + withTempPath { tempDir => + val testSchema = StructType(Seq( + StructField("Name", StringType, false), + StructField("Length", IntegerType, true), + StructField("Time", TimestampType, false), + StructField("Decimal", DecimalType(10, 2), true), + StructField("Binary", BinaryType, false))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + val cityRDD = spark.sparkContext.parallelize(Seq( + Row("San Francisco", 12, new Timestamp(666), null, arrayOfByte), + Row("Palo Alto", null, new Timestamp(777), null, arrayOfByte), + Row("Munich", 8, new Timestamp(42), Decimal(3.14), arrayOfByte))) + val cityDataFrame = spark.createDataFrame(cityRDD, testSchema) + + val avroDir = tempDir + "/avro" + cityDataFrame.write.format("avro").save(avroDir) + assert(spark.read.format("avro").load(avroDir).collect().length == 3) + + // TimesStamps are converted to longs + val times = spark.read.format("avro").load(avroDir).select("Time").collect() + assert(times.map(_(0)).toSet == + Set(new Timestamp(666), new Timestamp(777), new Timestamp(42))) + + // DecimalType should be converted to string + val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect() + assert(decimals.map(_(0)).contains(new java.math.BigDecimal("3.14"))) + + // There should be a null entry + val length = spark.read.format("avro").load(avroDir).select("Length").collect() + assert(length.map(_(0)).contains(null)) + + val binary = spark.read.format("avro").load(avroDir).select("Binary").collect() + for (i <- arrayOfByte.indices) { + assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i)) + } + } + } + + test("correctly read long as date/timestamp type") { + withTempPath { tempDir => + val currentTime = new Timestamp(System.currentTimeMillis()) + val currentDate = new Date(System.currentTimeMillis()) + val schema = StructType(Seq( + StructField("_1", DateType, false), StructField("_2", TimestampType, false))) + val writeDs = Seq((currentDate, currentTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.format("avro").save(avroDir) + assert(spark.read.format("avro").load(avroDir).collect().length == 1) + + val readDs = spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)] + + assert(readDs.collect().sameElements(writeDs.collect())) + } + } + + test("support of globbed paths") { + val resourceDir = testFile(".") + val e1 = spark.read.format("avro").load(resourceDir + "../*/episodes.avro").collect() + assert(e1.length == 8) + + val e2 = spark.read.format("avro").load(resourceDir + "../../*/*/episodes.avro").collect() + assert(e2.length == 8) + } + + test("does not coerce null date/timestamp value to 0 epoch.") { + withTempPath { tempDir => + val nullTime: Timestamp = null + val nullDate: Date = null + val schema = StructType(Seq( + StructField("_1", DateType, nullable = true), + 
StructField("_2", TimestampType, nullable = true)) + ) + val writeDs = Seq((nullDate, nullTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.format("avro").save(avroDir) + val readValues = + spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)].collect + + assert(readValues.size == 1) + assert(readValues.head == ((nullDate, nullTime))) + } + } + + test("support user provided avro schema") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "string", + | "type" : "string", + | "doc" : "Meaningless string of characters" + | }] + |} + """.stripMargin + val result = spark + .read + .option("avroSchema", avroSchema) + .format("avro") + .load(testAvro) + .collect() + val expected = spark.read.format("avro").load(testAvro).select("string").collect() + assert(result.sameElements(expected)) + } + + test("support user provided avro schema with defaults for missing fields") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "missingField", + | "type" : "string", + | "default" : "foo" + | }] + |} + """.stripMargin + val result = spark + .read + .option("avroSchema", avroSchema) + .format("avro").load(testAvro).select("missingField").first + assert(result === Row("foo")) + } + + test("support user provided avro schema for writing nullable enum type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "enum", + | "type": [{ "type": "enum", + | "name": "Suit", + | "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + | }, "null"] + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"), + Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))), + StructType(Seq(StructField("Suit", StringType, true)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing data not in the enum will throw an exception + val message = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))), + StructType(Seq(StructField("Suit", StringType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum")) + } + } + + test("support user provided avro schema for writing non-nullable enum type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "enum", + | "type": { "type": "enum", + | "name": "Suit", + | "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + | } + | }] + |} + """.stripMargin + + val dfWithNull = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"), + Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))), + StructType(Seq(StructField("Suit", StringType, true)))) + + val df = spark.createDataFrame(dfWithNull.na.drop().rdd, + StructType(Seq(StructField("Suit", 
StringType, false)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing nulls without using avro union type will + // throw an exception as avro uses union type to handle null. + val message1 = intercept[SparkException] { + dfWithNull.write.format("avro") + .option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.avro.AvroRuntimeException: Not a union:")) + + // Writing df containing data not in the enum will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))), + StructType(Seq(StructField("Suit", StringType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum")) + } + } + + test("support user provided avro schema for writing nullable fixed type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "fixed2", + | "type": [{ "type": "fixed", + | "size": 2, + | "name": "fixed2" + | }, "null"] + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168).map(_.toByte)), Row(null))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message1 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes")) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes")) + } + } + + test("support user provided avro schema for writing non-nullable fixed type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "fixed2", + | "type": { "type": "fixed", + | "size": 2, + | "name": "fixed2" + | } + | }] + 
|} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168).map(_.toByte)), Row(Array(1, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message1 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes")) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes")) + } + } + + test("throw exception if unable to write with user provided Avro schema") { + val input: Seq[(DataType, Schema.Type)] = Seq( + (NullType, NULL), + (BooleanType, BOOLEAN), + (ByteType, INT), + (ShortType, INT), + (IntegerType, INT), + (LongType, LONG), + (FloatType, FLOAT), + (DoubleType, DOUBLE), + (BinaryType, BYTES), + (DateType, INT), + (TimestampType, LONG), + (DecimalType(4, 2), BYTES) + ) + def assertException(f: () => AvroSerializer) { + val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] { + f() + }.getMessage + assert(message.contains("Cannot convert Catalyst type")) + } + + def resolveNullable(schema: Schema, nullable: Boolean): Schema = { + if (nullable && schema.getType != NULL) { + Schema.createUnion(schema, Schema.create(NULL)) + } else { + schema + } + } + for { + i <- input + j <- input + nullable <- Seq(true, false) + } if (i._2 != j._2) { + val avroType = resolveNullable(Schema.create(j._2), nullable) + val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable) + val avroMapType = resolveNullable(Schema.createMap(avroType), nullable) + val name = "foo" + val avroField = new Field(name, avroType, "", null) + val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava) + val avroRecordType = resolveNullable(recordSchema, nullable) + + val catalystType = i._1 + val catalystArrayType = ArrayType(catalystType, nullable) + val catalystMapType = MapType(StringType, catalystType, nullable) + val catalystStructType = StructType(Seq(StructField(name, catalystType, nullable))) + + for { + avro <- Seq(avroType, avroArrayType, avroMapType, avroRecordType) + catalyst <- Seq(catalystType, catalystArrayType, catalystMapType, catalystStructType) + } { + assertException(() => new AvroSerializer(catalyst, avro, nullable)) + } + } + } + + test("reading from invalid path throws 
exception") { + + // Directory given has no avro files + intercept[AnalysisException] { + withTempPath(dir => spark.read.format("avro").load(dir.getCanonicalPath)) + } + + intercept[AnalysisException] { + spark.read.format("avro").load("very/invalid/path/123.avro") + } + + // In case of globbed path that can't be matched to anything, another exception is thrown (and + // exception message is helpful) + intercept[AnalysisException] { + spark.read.format("avro").load("*/*/*/*/*/*/*/something.avro") + } + + intercept[FileNotFoundException] { + withTempPath { dir => + FileUtils.touch(new File(dir, "test")) + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { + spark.read.format("avro").load(dir.toString) + } + } + } + + intercept[FileNotFoundException] { + withTempPath { dir => + FileUtils.touch(new File(dir, "test")) + + spark + .read + .option("ignoreExtension", false) + .format("avro") + .load(dir.toString) + } + } + } + + test("SQL test insert overwrite") { + withTempPath { tempDir => + val tempEmptyDir = s"$tempDir/sqlOverwrite" + // Create a temp directory for table that will be overwritten + new File(tempEmptyDir).mkdirs() + spark.sql( + s""" + |CREATE TEMPORARY VIEW episodes + |USING avro + |OPTIONS (path "${episodesAvro}") + """.stripMargin.replaceAll("\n", " ")) + spark.sql( + s""" + |CREATE TEMPORARY VIEW episodesEmpty + |(name string, air_date string, doctor int) + |USING avro + |OPTIONS (path "$tempEmptyDir") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM episodes").collect().length === 8) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().isEmpty) + + spark.sql( + s""" + |INSERT OVERWRITE TABLE episodesEmpty + |SELECT * FROM episodes + """.stripMargin.replaceAll("\n", " ")) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().length == 8) + } + } + + test("test save and load") { + // Test if load works as expected + withTempPath { tempDir => + val df = spark.read.format("avro").load(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").save(tempSaveDir) + val newDf = spark.read.format("avro").load(tempSaveDir) + assert(newDf.count == 8) + } + } + + test("test load with non-Avro file") { + // Test if load works as expected + withTempPath { tempDir => + val df = spark.read.format("avro").load(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.format("avro").save(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { + val newDf = spark.read.format("avro").load(tempSaveDir) + assert(newDf.count() == 8) + } + } + } + + test("read avro with user defined schema: read partial columns") { + val partialColumns = StructType(Seq( + StructField("string", StringType, false), + StructField("simple_map", MapType(StringType, IntegerType), false), + StructField("complex_map", MapType(StringType, MapType(StringType, StringType)), false), + StructField("union_string_null", StringType, true), + StructField("union_int_long_null", LongType, true), + StructField("fixed3", BinaryType, true), + StructField("fixed2", BinaryType, true), + StructField("enum", StringType, false), + StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), + StructField("array_of_boolean", ArrayType(BooleanType), false), + StructField("bytes", BinaryType, true))) + val withSchema = 
spark.read.schema(partialColumns).format("avro").load(testAvro).collect() + val withOutSchema = spark + .read + .format("avro") + .load(testAvro) + .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", + "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") + .collect() + assert(withSchema.sameElements(withOutSchema)) + } + + test("read avro with user defined schema: read non-exist columns") { + val schema = + StructType( + Seq( + StructField("non_exist_string", StringType, true), + StructField( + "record", + StructType(Seq( + StructField("non_exist_field", StringType, false), + StructField("non_exist_field2", StringType, false))), + false))) + val withEmptyColumn = spark.read.schema(schema).format("avro").load(testAvro).collect() + + assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) + } + + test("read avro file partitioned") { + withTempPath { dir => + val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.format("avro").save(outputDir) + val input = spark.read.format("avro").load(outputDir) + assert(input.collect.toSet.size === 1024 * 3 + 1) + assert(input.rdd.partitions.size > 2) + } + } + + case class NestedBottom(id: Int, data: String) + + case class NestedMiddle(id: Int, data: NestedBottom) + + case class NestedTop(id: Int, data: NestedMiddle) + + test("Validate namespace in avro file that has nested records with the same name") { + withTempPath { dir => + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + writeDf.write.format("avro").save(dir.toString) + val schema = getAvroSchemaStringFromFiles(dir.toString) + assert(schema.contains("\"namespace\":\"topLevelRecord\"")) + assert(schema.contains("\"namespace\":\"topLevelRecord.data\"")) + } + } + + test("saving avro that has nested records with the same name") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + val outputFolder = s"$tempDir/duplicate_names/" + writeDf.write.format("avro").save(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.format("avro").load(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + test("check namespace - toAvroType") { + val sparkSchema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("address", StructType(Seq( + StructField("city", StringType, nullable = false), + StructField("state", StringType, nullable = false))), + nullable = false))) + val employeeType = SchemaConverters.toAvroType(sparkSchema, + recordName = "employee", + nameSpace = "foo.bar") + + assert(employeeType.getFullName == "foo.bar.employee") + assert(employeeType.getName == "employee") + assert(employeeType.getNamespace == "foo.bar") + + val addressType = employeeType.getField("address").schema() + assert(addressType.getFullName == "foo.bar.employee.address") + assert(addressType.getName == "address") + assert(addressType.getNamespace == "foo.bar.employee") + } + + test("check empty namespace - toAvroType") { + val sparkSchema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("address", StructType(Seq( + StructField("city", StringType, nullable = false), + StructField("state", StringType, nullable = false))), + 
nullable = false))) + val employeeType = SchemaConverters.toAvroType(sparkSchema, + recordName = "employee") + + assert(employeeType.getFullName == "employee") + assert(employeeType.getName == "employee") + assert(employeeType.getNamespace == null) + + val addressType = employeeType.getField("address").schema() + assert(addressType.getFullName == "employee.address") + assert(addressType.getName == "address") + assert(addressType.getNamespace == "employee") + } + + case class NestedMiddleArray(id: Int, data: Array[NestedBottom]) + + case class NestedTopArray(id: Int, data: NestedMiddleArray) + + test("saving avro that has nested records with the same name inside an array") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopArray(1, NestedMiddleArray(2, Array( + NestedBottom(3, "1"), NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_array/" + writeDf.write.format("avro").save(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.format("avro").load(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleMap(id: Int, data: Map[String, NestedBottom]) + + case class NestedTopMap(id: Int, data: NestedMiddleMap) + + test("saving avro that has nested records with the same name inside a map") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopMap(1, NestedMiddleMap(2, Map( + "1" -> NestedBottom(3, "1"), "2" -> NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_map/" + writeDf.write.format("avro").save(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.format("avro").load(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + test("SPARK-24805: do not ignore files without .avro extension by default") { + withTempDir { dir => + Files.copy( + Paths.get(new URL(episodesAvro).toURI), + Paths.get(dir.getCanonicalPath, "episodes")) + + val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes" + val df1 = spark.read.format("avro").load(fileWithoutExtension) + assert(df1.count == 8) + + val schema = new StructType() + .add("title", StringType) + .add("air_date", StringType) + .add("doctor", IntegerType) + val df2 = spark.read.schema(schema).format("avro").load(fileWithoutExtension) + assert(df2.count == 8) + } + } + + test("SPARK-24836: checking the ignoreExtension option") { + withTempPath { tempDir => + val df = spark.read.format("avro").load(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.format("avro").save(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + val newDf = spark + .read + .option("ignoreExtension", false) + .format("avro") + .load(tempSaveDir) + + assert(newDf.count == 8) + } + } + + test("SPARK-24836: ignoreExtension must override hadoop's config") { + withTempDir { dir => + Files.copy( + Paths.get(new URL(episodesAvro).toURI), + Paths.get(dir.getCanonicalPath, "episodes")) + + val hadoopConf = spark.sessionState.newHadoopConf() + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { + val newDf = spark + .read + .option("ignoreExtension", "true") + .format("avro") + .load(s"${dir.getCanonicalPath}/episodes") + 
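+        // Even though AvroFileFormat.IgnoreFilesWithoutExtensionProperty is set to true above,
+        // the explicit ignoreExtension=true option takes precedence, so the extension-less
+        // file is still read.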
assert(newDf.count() == 8) + } + } + } + + test("SPARK-24881: write with compression - avro options") { + def getCodec(dir: String): Option[String] = { + val files = new File(dir) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("avro")) + files.map { file => + val reader = new DataFileReader(file, new GenericDatumReader[Any]()) + val r = reader.getMetaString("avro.codec") + r + }.map(v => if (v == "null") "uncompressed" else v).headOption + } + def checkCodec(df: DataFrame, dir: String, codec: String): Unit = { + val subdir = s"$dir/$codec" + df.write.option("compression", codec).format("avro").save(subdir) + assert(getCodec(subdir) == Some(codec)) + } + withTempPath { dir => + val path = dir.toString + val df = spark.read.format("avro").load(testAvro) + + checkCodec(df, path, "uncompressed") + checkCodec(df, path, "deflate") + checkCodec(df, path, "snappy") + checkCodec(df, path, "bzip2") + checkCodec(df, path, "xz") + } + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 8512496e5fe52..09a2cd83aed6b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.jdbc +import java.math.BigDecimal import java.sql.{Connection, Date, Timestamp} import java.util.{Properties, TimeZone} -import java.math.BigDecimal -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.execution.{RowDataSourceScanExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -86,7 +88,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo conn.prepareStatement( "CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 NUMBER(1))").executeUpdate() conn.prepareStatement( - "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)").executeUpdate() + "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)") + .executeUpdate() conn.commit() sql( @@ -108,15 +111,36 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate() + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))") + .executeUpdate() conn.prepareStatement( "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate() conn.commit() - conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate() + conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)") + .executeUpdate() conn.commit() - } + conn.prepareStatement("CREATE TABLE datetimePartitionTest (id NUMBER(10), d DATE, t TIMESTAMP)") + .executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(1, {d '2018-07-06'}, {ts '2018-07-06 05:50:00'}) + 
""".stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(2, {d '2018-07-06'}, {ts '2018-07-06 08:10:08'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(3, {d '2018-07-08'}, {ts '2018-07-08 13:32:01'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(4, {d '2018-07-12'}, {ts '2018-07-12 09:51:15'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.commit() + } test("SPARK-16625 : Importing Oracle numeric types") { val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties) @@ -399,4 +423,54 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo assert(values.getDouble(0) === 1.1) assert(values.getFloat(1) === 2.2f) } + + test("SPARK-22814 support date/timestamp types in partitionColumn") { + val expectedResult = Set( + (1, "2018-07-06", "2018-07-06 05:50:00"), + (2, "2018-07-06", "2018-07-06 08:10:08"), + (3, "2018-07-08", "2018-07-08 13:32:01"), + (4, "2018-07-12", "2018-07-12 09:51:15") + ).map { case (id, date, timestamp) => + Row(BigDecimal.valueOf(id), Date.valueOf(date), Timestamp.valueOf(timestamp)) + } + + // DateType partition column + val df1 = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "datetimePartitionTest") + .option("partitionColumn", "d") + .option("lowerBound", "2018-07-06") + .option("upperBound", "2018-07-20") + .option("numPartitions", 3) + .load() + + df1.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"D" < '2018-07-10' or "D" is null""", + """"D" >= '2018-07-10' AND "D" < '2018-07-14'""", + """"D" >= '2018-07-14'""")) + } + assert(df1.collect.toSet === expectedResult) + + // TimestampType partition column + val df2 = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "datetimePartitionTest") + .option("partitionColumn", "t") + .option("lowerBound", "2018-07-04 03:30:00.0") + .option("upperBound", "2018-07-27 14:11:05.0") + .option("numPartitions", 2) + .load() + + df2.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"T" < '2018-07-15 20:50:32.5' or "T" is null""", + """"T" >= '2018-07-15 20:50:32.5'""")) + } + assert(df2.collect.toSet === expectedResult) + } } diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index a742b8d6dbddb..f80f8e3a0183d 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -95,11 +95,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.scala-lang scala-library diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 16bbc6db641ca..8588e8be052eb 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -29,10 +29,11 @@ spark-sql-kafka-0-10_2.11 sql-kafka-0-10 - 0.10.0.1 + + 2.0.0 jar - Kafka 0.10 Source for Structured Streaming + Kafka 0.10+ Source for Structured Streaming http://spark.apache.org/ @@ -73,6 +74,20 @@ kafka_${scala.binary.version} ${kafka.version} test + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + 
jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + net.sf.jopt-simple @@ -80,6 +95,12 @@ 3.2 test + + org.eclipse.jetty + jetty-servlet + ${jetty.version} + test + org.scalacheck scalacheck_${scala.binary.version} @@ -108,13 +129,4 @@ target/scala-${scala.binary.version}/test-classes - - - scala-2.12 - - 0.10.1.1 - - - - diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala index 571140b0afbc7..cd680adf44365 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala @@ -33,8 +33,12 @@ private[kafka010] object CachedKafkaProducer extends Logging { private type Producer = KafkaProducer[Array[Byte], Array[Byte]] + private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10) + private lazy val cacheExpireTimeout: Long = - SparkEnv.get.conf.getTimeAsMs("spark.kafka.producer.cache.timeout", "10m") + Option(SparkEnv.get).map(_.conf.getTimeAsMs( + "spark.kafka.producer.cache.timeout", + s"${defaultCacheExpireTimeout}ms")).getOrElse(defaultCacheExpireTimeout) private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] { override def load(config: Seq[(String, Object)]): Producer = { @@ -102,7 +106,7 @@ private[kafka010] object CachedKafkaProducer extends Logging { } } - private def clear(): Unit = { + private[kafka010] def clear(): Unit = { logInfo("Cleaning up guava cache.") guavaCache.invalidateAll() } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala index 868edb5dcdc0c..92b13f2b555d1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -29,6 +29,11 @@ import org.json4s.jackson.Serialization */ private object JsonUtils { private implicit val formats = Serialization.formats(NoTypeHints) + implicit val ordering = new Ordering[TopicPartition] { + override def compare(x: TopicPartition, y: TopicPartition): Int = { + Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition)) + } + } /** * Read TopicPartitions from json string @@ -51,7 +56,7 @@ private object JsonUtils { * Write TopicPartitions as json string */ def partitions(partitions: Iterable[TopicPartition]): String = { - val result = new HashMap[String, List[Int]] + val result = HashMap.empty[String, List[Int]] partitions.foreach { tp => val parts: List[Int] = result.getOrElse(tp.topic, Nil) result += tp.topic -> (tp.partition::parts) @@ -80,19 +85,31 @@ private object JsonUtils { * Write per-TopicPartition offsets as json string */ def partitionOffsets(partitionOffsets: Map[TopicPartition, Long]): String = { - val result = new HashMap[String, HashMap[Int, Long]]() - implicit val ordering = new Ordering[TopicPartition] { - override def compare(x: TopicPartition, y: TopicPartition): Int = { - Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition)) - } - } + val result = HashMap.empty[String, HashMap[Int, Long]] val partitions = partitionOffsets.keySet.toSeq.sorted // sort for more determinism partitions.foreach { tp => val off = partitionOffsets(tp) - val parts = 
result.getOrElse(tp.topic, new HashMap[Int, Long]) + val parts = result.getOrElse(tp.topic, HashMap.empty[Int, Long]) parts += tp.partition -> off result += tp.topic -> parts } Serialization.write(result) } + + /** + * Write per-topic partition lag as json string + */ + def partitionLags( + latestOffsets: Map[TopicPartition, Long], + processedOffsets: Map[TopicPartition, Long]): String = { + val result = HashMap.empty[String, HashMap[Int, Long]] + val partitions = latestOffsets.keySet.toSeq.sorted + partitions.foreach { tp => + val lag = latestOffsets(tp) - processedOffsets.getOrElse(tp, 0L) + val parts = result.getOrElse(tp.topic, HashMap.empty[Int, Long]) + parts += tp.partition -> lag + result += tp.topic -> parts + } + Serialization.write(Map("lag" -> result)) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala similarity index 70% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala index f26c134c2f6e9..1753a28fba2fb 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala @@ -25,15 +25,15 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType /** - * A [[ContinuousReader]] for data from kafka. + * A [[ContinuousReadSupport]] for data from kafka. * * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. @@ -46,70 +46,49 @@ import org.apache.spark.sql.types.StructType * scenarios, where some offsets after the specified initial ones can't be * properly read. */ -class KafkaContinuousReader( +class KafkaContinuousReadSupport( offsetReader: KafkaOffsetReader, kafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { - - private lazy val session = SparkSession.getActiveSession.get - private lazy val sc = session.sparkContext + extends ContinuousReadSupport with Logging { private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong - // Initialized when creating reader factories. If this diverges from the partitions at the latest - // offsets, we need to reconfigure. - // Exposed outside this object only for unit tests. 
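// Illustrative sketch, not part of this patch: the mutable `knownPartitions` field removed just
// below is replaced by state carried in the immutable `KafkaContinuousScanConfig` added later in
// this file, so `needsReconfiguration` becomes a plain comparison against the latest fetched
// partition set. REPL-style snippet; the partition values are hypothetical stand-ins for what the
// offset reader would return.
import org.apache.kafka.common.TopicPartition

val knownPartitions: Set[TopicPartition] = Set(new TopicPartition("topic", 0))
val latestPartitions: Set[TopicPartition] =
  Set(new TopicPartition("topic", 0), new TopicPartition("topic", 1))
// A partition appeared (or disappeared), so the continuous query must be reconfigured and new
// reader tasks planned for the current partition set.
val needsReconfiguration: Boolean = latestPartitions != knownPartitions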
- @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ - - override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - - private var offset: Offset = _ - override def setStartOffset(start: ju.Optional[Offset]): Unit = { - offset = start.orElse { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets + override def initialOffset(): Offset = { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) } + logInfo(s"Initial offsets: $offsets") + offsets } - override def getStartOffset(): Offset = offset + override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss) + } override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { - import scala.collection.JavaConverters._ - - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - knownPartitions = startOffsets.keySet - + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousDataReaderFactory( + KafkaContinuousInputPartition( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[DataReaderFactory[UnsafeRow]] - }.asJava + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + KafkaContinuousReaderFactory } /** Stop this source and free any resources it has allocated. */ @@ -126,8 +105,9 @@ class KafkaContinuousReader( KafkaSourceOffset(mergedMap) } - override def needsReconfiguration(): Boolean = { - knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + override def needsReconfiguration(config: ScanConfig): Boolean = { + val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions + offsetReader.fetchLatestOffsets().keySet != knownPartitions } override def toString(): String = s"KafkaSource[$offsetReader]" @@ -146,7 +126,7 @@ class KafkaContinuousReader( } /** - * A data reader factory for continuous Kafka processing. 
This will be serialized and transformed + * An input partition for continuous Kafka processing. This will be serialized and transformed * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. @@ -156,27 +136,56 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousDataReaderFactory( +case class KafkaContinuousInputPartition( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { - - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { - val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] - require(kafkaOffset.topicPartition == topicPartition, - s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousDataReader( - topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + failOnDataLoss: Boolean) extends InputPartition + +object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[KafkaContinuousInputPartition] + new KafkaContinuousPartitionReader( + p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss) } +} + +class KafkaContinuousScanConfigBuilder( + schema: StructType, + startOffset: Offset, + offsetReader: KafkaOffsetReader, + reportDataLoss: String => Unit) + extends ScanConfigBuilder { + + override def build(): ScanConfig = { + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader( - topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + KafkaContinuousScanConfig(schema, startOffsets) } } +case class KafkaContinuousScanConfig( + readSchema: StructType, + startOffsets: Map[TopicPartition, Long]) + extends ScanConfig { + + // Created when building the scan config builder. If this diverges from the partitions at the + // latest offsets, we need to reconfigure the kafka read support. + def knownPartitions: Set[TopicPartition] = startOffsets.keySet +} + /** * A per-task data reader for continuous Kafka processing. * @@ -187,12 +196,12 @@ case class KafkaContinuousDataReaderFactory( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. 
*/ -class KafkaContinuousDataReader( +class KafkaContinuousPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter @@ -214,11 +223,11 @@ class KafkaContinuousDataReader( } catch { // We didn't read within the timeout. We're supposed to block indefinitely for new data, so // swallow and ignore this. - case _: TimeoutException => + case _: TimeoutException | _: org.apache.kafka.common.errors.TimeoutException => // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, // or if it's the endpoint of the data range (i.e. the "true" next offset). - case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => + case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => val range = consumer.getAvailableOffsetRange() if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) { // retry diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 48508d057a540..ceb9e318b283b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -33,9 +33,19 @@ import org.apache.spark.util.UninterruptibleThread private[kafka010] sealed trait KafkaDataConsumer { /** - * Get the record for the given offset if available. Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. + * Get the record for the given offset if available. + * + * If the record is invisible (either a + * transaction message, or an aborted message when the consumer's `isolation.level` is + * `read_committed`), it will be skipped and this method will try to fetch next available record + * within [offset, untilOffset). + * + * This method also will try its best to detect data loss. If `failOnDataLoss` is `true`, it will + * throw an exception when we detect an unavailable offset. If `failOnDataLoss` is `false`, this + * method will try to fetch next available record within [offset, untilOffset). + * + * When this method tries to skip offsets due to either invisible messages or data loss and + * reaches `untilOffset`, it will return `null`. * * @param offset the offset to fetch. * @param untilOffset the max offset to fetch. Exclusive. @@ -80,6 +90,83 @@ private[kafka010] case class InternalKafkaConsumer( kafkaParams: ju.Map[String, Object]) extends Logging { import InternalKafkaConsumer._ + /** + * The internal object to store the fetched data from Kafka consumer and the next offset to poll. + * + * @param _records the pre-fetched Kafka records. + * @param _nextOffsetInFetchedData the next offset in `records`. We use this to verify if we + * should check if the pre-fetched data is still valid. + * @param _offsetAfterPoll the Kafka offset after calling `poll`. We will use this offset to + * poll when `records` is drained. 
+ */ + private case class FetchedData( + private var _records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + private var _nextOffsetInFetchedData: Long, + private var _offsetAfterPoll: Long) { + + def withNewPoll( + records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + offsetAfterPoll: Long): FetchedData = { + this._records = records + this._nextOffsetInFetchedData = UNKNOWN_OFFSET + this._offsetAfterPoll = offsetAfterPoll + this + } + + /** Whether there are more elements */ + def hasNext: Boolean = _records.hasNext + + /** Move `records` forward and return the next record. */ + def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { + val record = _records.next() + _nextOffsetInFetchedData = record.offset + 1 + record + } + + /** Move `records` backward and return the previous record. */ + def previous(): ConsumerRecord[Array[Byte], Array[Byte]] = { + assert(_records.hasPrevious, "fetchedData cannot move back") + val record = _records.previous() + _nextOffsetInFetchedData = record.offset + record + } + + /** Reset the internal pre-fetched data. */ + def reset(): Unit = { + _records = ju.Collections.emptyListIterator() + } + + /** + * Returns the next offset in `records`. We use this to verify if we should check if the + * pre-fetched data is still valid. + */ + def nextOffsetInFetchedData: Long = _nextOffsetInFetchedData + + /** + * Returns the next offset to poll after draining the pre-fetched records. + */ + def offsetAfterPoll: Long = _offsetAfterPoll + } + + /** + * The internal object returned by the `fetchRecord` method. If `record` is empty, it means it is + * invisible (either a transaction message, or an aborted message when the consumer's + * `isolation.level` is `read_committed`), and the caller should use `nextOffsetToFetch` to fetch + * instead. + */ + private case class FetchedRecord( + var record: ConsumerRecord[Array[Byte], Array[Byte]], + var nextOffsetToFetch: Long) { + + def withRecord( + record: ConsumerRecord[Array[Byte], Array[Byte]], + nextOffsetToFetch: Long): FetchedRecord = { + this.record = record + this.nextOffsetToFetch = nextOffsetToFetch + this + } + } + private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] @volatile private var consumer = createConsumer @@ -90,10 +177,21 @@ private[kafka010] case class InternalKafkaConsumer( /** indicate whether this consumer is going to be stopped in the next release */ @volatile var markedForClose = false - /** Iterator to the already fetch data */ - @volatile private var fetchedData = - ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET + /** + * The fetched data returned from Kafka consumer. This is a reusable private object to avoid + * memory allocation. + */ + private val fetchedData = FetchedData( + ju.Collections.emptyListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + UNKNOWN_OFFSET, + UNKNOWN_OFFSET) + + /** + * The fetched record returned from the `fetchRecord` method. This is a reusable private object to + * avoid memory allocation. + */ + private val fetchedRecord: FetchedRecord = FetchedRecord(null, UNKNOWN_OFFSET) + /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { @@ -125,20 +223,7 @@ private[kafka010] case class InternalKafkaConsumer( AvailableOffsetRange(earliestOffset, latestOffset) } - /** - * Get the record for the given offset if available. 
Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. - * - * @param offset the offset to fetch. - * @param untilOffset the max offset to fetch. Exclusive. - * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. - * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at - * offset if available, or throw exception.when `failOnDataLoss` is `false`, - * this method will either return record at offset if available, or return - * the next earliest available record less than untilOffset, or null. It - * will not throw any exception. - */ + /** @see [[KafkaDataConsumer.get]] */ def get( offset: Long, untilOffset: Long, @@ -147,21 +232,32 @@ private[kafka010] case class InternalKafkaConsumer( ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible { require(offset < untilOffset, s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") - logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") + logDebug(s"Get $groupId $topicPartition nextOffset ${fetchedData.nextOffsetInFetchedData} " + + s"requested $offset") // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is // `false`, first, we will try to fetch the record at `offset`. If no such record exists, then // we will move to the next available offset within `[offset, untilOffset)` and retry. // If `failOnDataLoss` is `true`, the loop body will be executed only once. var toFetchOffset = offset - var consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]] = null + var fetchedRecord: FetchedRecord = null // We want to break out of the while loop on a successful fetch to avoid using "return" - // which may causes a NonLocalReturnControl exception when this method is used as a function. + // which may cause a NonLocalReturnControl exception when this method is used as a function. var isFetchComplete = false while (toFetchOffset != UNKNOWN_OFFSET && !isFetchComplete) { try { - consumerRecord = fetchData(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss) - isFetchComplete = true + fetchedRecord = fetchRecord(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss) + if (fetchedRecord.record != null) { + isFetchComplete = true + } else { + toFetchOffset = fetchedRecord.nextOffsetToFetch + if (toFetchOffset >= untilOffset) { + fetchedData.reset() + toFetchOffset = UNKNOWN_OFFSET + } else { + logDebug(s"Skipped offsets [$offset, $toFetchOffset]") + } + } } catch { case e: OffsetOutOfRangeException => // When there is some error thrown, it's better to use a new consumer to drop all cached @@ -174,9 +270,9 @@ private[kafka010] case class InternalKafkaConsumer( } if (isFetchComplete) { - consumerRecord + fetchedRecord.record } else { - resetFetchedData() + fetchedData.reset() null } } @@ -239,57 +335,73 @@ private[kafka010] case class InternalKafkaConsumer( } /** - * Get the record for the given offset if available. Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. + * Get the fetched record for the given offset if available. + * + * If the record is invisible (either a transaction message, or an aborted message when the + * consumer's `isolation.level` is `read_committed`), it will return a `FetchedRecord` with the + * next offset to fetch. 
+ * + * This method also will try the best to detect data loss. If `failOnDataLoss` is true`, it will + * throw an exception when we detect an unavailable offset. If `failOnDataLoss` is `false`, this + * method will return `null` if the next available record is within [offset, untilOffset). * * @throws OffsetOutOfRangeException if `offset` is out of range * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds. */ - private def fetchData( + private def fetchRecord( offset: Long, untilOffset: Long, pollTimeoutMs: Long, - failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { - if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) { - // This is the first fetch, or the last pre-fetched data has been drained. - // Seek to the offset because we may call seekToBeginning or seekToEnd before this. - seek(offset) - poll(pollTimeoutMs) - } - - if (!fetchedData.hasNext()) { - // We cannot fetch anything after `poll`. Two possible cases: - // - `offset` is out of range so that Kafka returns nothing. Just throw - // `OffsetOutOfRangeException` to let the caller handle it. - // - Cannot fetch any data before timeout. TimeoutException will be thrown. - val range = getAvailableOffsetRange() - if (offset < range.earliest || offset >= range.latest) { - throw new OffsetOutOfRangeException( - Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) + failOnDataLoss: Boolean): FetchedRecord = { + if (offset != fetchedData.nextOffsetInFetchedData) { + // This is the first fetch, or the fetched data has been reset. + // Fetch records from Kafka and update `fetchedData`. + fetchData(offset, pollTimeoutMs) + } else if (!fetchedData.hasNext) { // The last pre-fetched data has been drained. + if (offset < fetchedData.offsetAfterPoll) { + // Offsets in [offset, fetchedData.offsetAfterPoll) are invisible. Return a record to ask + // the next call to start from `fetchedData.offsetAfterPoll`. + fetchedData.reset() + return fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll) } else { - throw new TimeoutException( - s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") + // Fetch records from Kafka and update `fetchedData`. + fetchData(offset, pollTimeoutMs) } + } + + if (!fetchedData.hasNext) { + // When we reach here, we have already tried to poll from Kafka. As `fetchedData` is still + // empty, all messages in [offset, fetchedData.offsetAfterPoll) are invisible. Return a + // record to ask the next call to start from `fetchedData.offsetAfterPoll`. + assert(offset <= fetchedData.offsetAfterPoll, + s"seek to $offset and poll but the offset was reset to ${fetchedData.offsetAfterPoll}") + fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll) } else { val record = fetchedData.next() - nextOffsetInFetchedData = record.offset + 1 // In general, Kafka uses the specified offset as the start point, and tries to fetch the next // available offset. Hence we need to handle offset mismatch. if (record.offset > offset) { + val range = getAvailableOffsetRange() + if (range.earliest <= offset) { + // `offset` is still valid but the corresponding message is invisible. We should skip it + // and jump to `record.offset`. Here we move `fetchedData` back so that the next call of + // `fetchRecord` can just return `record` directly. 
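// Illustrative sketch, not part of this patch: the offset-mismatch handling in `fetchRecord`
// (this branch) amounts to a three-way decision. `FetchOutcome`, `Emit`, `SkipTo` and `Lost` are
// hypothetical names used only to make the cases explicit; `availableEarliest` stands in for
// `getAvailableOffsetRange().earliest`.
sealed trait FetchOutcome
case class Emit(offset: Long) extends FetchOutcome      // returned offset matches the request
case class SkipTo(offset: Long) extends FetchOutcome    // requested offset is invisible, jump ahead
case class Lost(from: Long, until: Long) extends FetchOutcome  // records aged out; depends on failOnDataLoss

def classify(requested: Long, returnedOffset: Long, availableEarliest: Long): FetchOutcome = {
  // A returned offset smaller than the request is treated as an internal error in the real code.
  require(returnedOffset >= requested, "unexpected smaller offset from Kafka")
  if (returnedOffset == requested) Emit(returnedOffset)
  else if (availableEarliest <= requested) SkipTo(returnedOffset)
  else Lost(requested, returnedOffset)
}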
+ fetchedData.previous() + return fetchedRecord.withRecord(null, record.offset) + } // This may happen when some records aged out but their offsets already got verified if (failOnDataLoss) { reportDataLoss(true, s"Cannot fetch records in [$offset, ${record.offset})") // Never happen as "reportDataLoss" will throw an exception - null + throw new IllegalStateException( + "reportDataLoss didn't throw an exception when 'failOnDataLoss' is true") + } else if (record.offset >= untilOffset) { + reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)") + // Set `nextOffsetToFetch` to `untilOffset` to finish the current batch. + fetchedRecord.withRecord(null, untilOffset) } else { - if (record.offset >= untilOffset) { - reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)") - null - } else { - reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})") - record - } + reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})") + fetchedRecord.withRecord(record, fetchedData.nextOffsetInFetchedData) } } else if (record.offset < offset) { // This should not happen. If it does happen, then we probably misunderstand Kafka internal @@ -297,7 +409,7 @@ private[kafka010] case class InternalKafkaConsumer( throw new IllegalStateException( s"Tried to fetch $offset but the returned record offset was ${record.offset}") } else { - record + fetchedRecord.withRecord(record, fetchedData.nextOffsetInFetchedData) } } } @@ -306,13 +418,7 @@ private[kafka010] case class InternalKafkaConsumer( private def resetConsumer(): Unit = { consumer.close() consumer = createConsumer - resetFetchedData() - } - - /** Reset the internal pre-fetched data. */ - private def resetFetchedData(): Unit = { - nextOffsetInFetchedData = UNKNOWN_OFFSET - fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + fetchedData.reset() } /** @@ -346,11 +452,40 @@ private[kafka010] case class InternalKafkaConsumer( consumer.seek(topicPartition, offset) } - private def poll(pollTimeoutMs: Long): Unit = { + /** + * Poll messages from Kafka starting from `offset` and update `fetchedData`. `fetchedData` may be + * empty if the Kafka consumer fetches some messages but all of them are not visible messages + * (either transaction messages, or aborted messages when `isolation.level` is `read_committed`). + * + * @throws OffsetOutOfRangeException if `offset` is out of range. + * @throws TimeoutException if the consumer position is not changed after polling. It means the + * consumer polls nothing before timeout. + */ + private def fetchData(offset: Long, pollTimeoutMs: Long): Unit = { + // Seek to the offset because we may call seekToBeginning or seekToEnd before this. + seek(offset) val p = consumer.poll(pollTimeoutMs) val r = p.records(topicPartition) logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") - fetchedData = r.iterator + val offsetAfterPoll = consumer.position(topicPartition) + logDebug(s"Offset changed from $offset to $offsetAfterPoll after polling") + fetchedData.withNewPoll(r.listIterator, offsetAfterPoll) + if (!fetchedData.hasNext) { + // We cannot fetch anything after `poll`. Two possible cases: + // - `offset` is out of range so that Kafka returns nothing. `OffsetOutOfRangeException` will + // be thrown. + // - Cannot fetch any data before timeout. `TimeoutException` will be thrown. + // - Fetched something but all of them are not invisible. This is a valid case and let the + // caller handles this. 
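// Note, not part of this patch: in the third case above (the consumer did return records, but all
// of them were invisible, i.e. transaction markers or aborted messages under
// `isolation.level=read_committed`), the consumer position has still advanced to
// `offsetAfterPoll`, so neither exception below is thrown (`offset` is within range and
// `offset != offsetAfterPoll`); `fetchRecord` then surfaces the gap to its caller through
// `FetchedRecord.nextOffsetToFetch`.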
+ val range = getAvailableOffsetRange() + if (offset < range.earliest || offset >= range.latest) { + throw new OffsetOutOfRangeException( + Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) + } else if (offset == offsetAfterPoll) { + throw new TimeoutException( + s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") + } + } } } @@ -395,7 +530,7 @@ private[kafka010] object KafkaDataConsumer extends Logging { // likely running on a beefy machine that can handle a large number of simultaneously // active consumers. - if (entry.getValue.inUse == false && this.size > capacity) { + if (!entry.getValue.inUse && this.size > capacity) { logWarning( s"KafkaConsumer cache hitting max capacity of $capacity, " + s"removing consumer for ${entry.getKey}") diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala similarity index 76% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala index 2ed49ba3f5495..70f37e32e78db 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala @@ -21,25 +21,26 @@ import java.{util => ju} import java.io._ import java.nio.charset.StandardCharsets -import scala.collection.JavaConverters._ - import org.apache.commons.io.IOUtils +import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} +import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset, SupportsCustomReaderMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread /** - * A [[MicroBatchReader]] that reads data from Kafka. + * A [[MicroBatchReadSupport]] that reads data from Kafka. * * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). 
For @@ -54,54 +55,57 @@ import org.apache.spark.util.UninterruptibleThread * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers * and not use wrong broker addresses. */ -private[kafka010] class KafkaMicroBatchReader( +private[kafka010] class KafkaMicroBatchReadSupport( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], options: DataSourceOptions, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - - private var startPartitionOffsets: PartitionOffsetMap = _ - private var endPartitionOffsets: PartitionOffsetMap = _ + extends RateControlMicroBatchReadSupport with SupportsCustomReaderMetrics with Logging { private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", - SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) private val rangeCalculator = KafkaOffsetRangeCalculator(options) + + private var endPartitionOffsets: KafkaSourceOffset = _ + /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running * `KafkaConsumer.poll` may hang forever (KAFKA-1894). */ - private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() - - override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { - // Make sure initialPartitionOffsets is initialized - initialPartitionOffsets - - startPartitionOffsets = Option(start.orElse(null)) - .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) - .getOrElse(initialPartitionOffsets) - - endPartitionOffsets = Option(end.orElse(null)) - .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) - .getOrElse { - val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() - maxOffsetsPerTrigger.map { maxOffsets => - rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) - }.getOrElse { - latestPartitionOffsets - } - } + override def initialOffset(): Offset = { + KafkaSourceOffset(getOrCreateInitialPartitionOffsets()) } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def latestOffset(start: Offset): Offset = { + val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + endPartitionOffsets = KafkaSourceOffset(maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + }) + endPartitionOffsets + } + + override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets + val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets + // Find the new partitions, and get their earliest offsets val newPartitions = 
endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -142,34 +146,34 @@ private[kafka010] class KafkaMicroBatchReader( val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size // Generate factories based on the offset ranges - val factories = offsetRanges.map { range => - new KafkaMicroBatchDataReaderFactory( + offsetRanges.map { range => + KafkaMicroBatchInputPartition( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) - } - factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + }.toArray } - override def getStartOffset: Offset = { - KafkaSourceOffset(startPartitionOffsets) + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + KafkaMicroBatchReaderFactory } - override def getEndOffset: Offset = { - KafkaSourceOffset(endPartitionOffsets) + // TODO: figure out the life cycle of custom metrics, and make this method take `ScanConfig` as + // a parameter. + override def getCustomMetrics(): CustomMetrics = { + KafkaCustomMetrics( + kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets.partitionToOffsets) } override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema - override def commit(end: Offset): Unit = {} override def stop(): Unit = { kafkaOffsetReader.close() } - override def toString(): String = s"Kafka[$kafkaOffsetReader]" + override def toString(): String = s"KafkaV2[$kafkaOffsetReader]" /** * Read initial partition offsets from the checkpoint, or decide the offsets and write them to @@ -299,27 +303,29 @@ private[kafka010] class KafkaMicroBatchReader( } } -/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReaderFactory( +/** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { - - override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray + reuseKafkaConsumer: Boolean) extends InputPartition - override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) +private[kafka010] object KafkaMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[KafkaMicroBatchInputPartition] + KafkaMicroBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs, + p.failOnDataLoss, p.reuseKafkaConsumer) + } } -/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReader( +/** A [[PartitionReader]] for reading Kafka data in a micro-batch streaming query. 
*/ +private[kafka010] case class KafkaMicroBatchPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + reuseKafkaConsumer: Boolean) extends PartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) @@ -335,6 +341,7 @@ private[kafka010] case class KafkaMicroBatchDataReader( val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) if (record != null) { nextRow = converter.toUnsafeRow(record) + nextOffset = record.offset + 1 true } else { false @@ -346,7 +353,6 @@ private[kafka010] case class KafkaMicroBatchDataReader( override def get(): UnsafeRow = { assert(nextRow != null) - nextOffset += 1 nextRow } @@ -378,3 +384,18 @@ private[kafka010] case class KafkaMicroBatchDataReader( } } } + +/** + * Currently reports per topic-partition lag. + * This is the difference between the offset of the latest available data + * in a topic-partition and the latest offset that has been processed. + */ +private[kafka010] case class KafkaCustomMetrics( + latestOffsets: Map[TopicPartition, Long], + processedOffsets: Map[TopicPartition, Long]) extends CustomMetrics { + override def json(): String = { + JsonUtils.partitionLags(latestOffsets, processedOffsets) + } + + override def toString: String = json() +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..82066697cb95a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -75,7 +75,17 @@ private[kafka010] class KafkaOffsetReader( * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. */ - protected var consumer = createConsumer() + @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null + + protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer == null) { + val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) + newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) + _consumer = consumerStrategy.createConsumer(newKafkaParams) + } + _consumer + } private val maxOffsetFetchAttempts = readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt @@ -95,9 +105,7 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - runUninterruptibly { - consumer.close() - } + if (_consumer != null) runUninterruptibly { stopConsumer() } kafkaReaderThread.shutdown() } @@ -304,19 +312,14 @@ private[kafka010] class KafkaOffsetReader( } } - /** - * Create a consumer using the new generated group id. We always use a new consumer to avoid - * just using a broken consumer to retry on Kafka errors, which likely will fail again. 
- */ - private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized { - val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) - newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) - consumerStrategy.createConsumer(newKafkaParams) + private def stopConsumer(): Unit = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer != null) _consumer.close() } private def resetConsumer(): Unit = synchronized { - consumer.close() - consumer = createConsumer() + stopConsumer() + _consumer = null // will automatically get reinitialized again } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index 7103709969c18..9d856c9494e10 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -48,7 +48,9 @@ private[kafka010] class KafkaRelation( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sqlContext.sparkContext.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sqlContext.sparkContext.conf.getTimeAsSeconds( + "spark.network.timeout", + "120s") * 1000L).toString ).toLong override def schema: StructType = KafkaOffsetReader.kafkaSchema @@ -115,7 +117,7 @@ private[kafka010] class KafkaRelation( DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)), cr.timestampType.id) } - sqlContext.internalCreateDataFrame(rdd, schema).rdd + sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd } private def getPartitionOffsets( diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1c7b3a29a861f..66ec7e0cd084a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -84,7 +84,7 @@ private[kafka010] class KafkaSource( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L).toString ).toLong private val maxOffsetsPerTrigger = @@ -215,7 +215,7 @@ private[kafka010] class KafkaSource( } if (start.isDefined && start.get == end) { return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + sqlContext.sparkContext.emptyRDD[InternalRow].setName("empty"), schema, isStreaming = true) } val fromPartitionOffsets = start match { case Some(prevBatchEndOffset) => @@ -299,7 +299,7 @@ private[kafka010] class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema, isStreaming = true) } /** Stop this source and free any resources it has allocated. 
*/ diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 36b9f0466566b..28c9853bfea9c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,8 +30,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -45,9 +45,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamWriteSupport - with ContinuousReadSupport - with MicroBatchReadSupport + with StreamingWriteSupportProvider + with ContinuousReadSupportProvider + with MicroBatchReadSupportProvider with Logging { import KafkaSourceProvider._ @@ -107,13 +107,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches - * of Kafka data in a micro-batch streaming query. + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport]] to read + * batches of Kafka data in a micro-batch streaming query. */ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( metadataPath: String, - options: DataSourceOptions): KafkaMicroBatchReader = { + options: DataSourceOptions): KafkaMicroBatchReadSupport = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) @@ -139,7 +138,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaMicroBatchReader( + new KafkaMicroBatchReadSupport( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), options, @@ -149,13 +148,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport]] to read * Kafka data in a continuous streaming query. */ - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( metadataPath: String, - options: DataSourceOptions): KafkaContinuousReader = { + options: DataSourceOptions): KafkaContinuousReadSupport = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. 
Otherwise, the query may be only assigned @@ -180,7 +178,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaContinuousReader( + new KafkaContinuousReadSupport( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), parameters, @@ -269,11 +267,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { + options: DataSourceOptions): StreamingWriteSupport = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get @@ -284,7 +282,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister KafkaWriter.validateQuery( schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) - new KafkaStreamWriter(topic, producerParams, schema) + new KafkaStreamingWriteSupport(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 498e344ea39f4..f8b90056d2931 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -77,44 +77,6 @@ private[kafka010] class KafkaSourceRDD( offsetRanges.zipWithIndex.map { case (o, i) => new KafkaSourceRDDPartition(i, o) }.toArray } - override def count(): Long = offsetRanges.map(_.size).sum - - override def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = { - val c = count - new PartialResult(new BoundedDouble(c, 1.0, c, c), true) - } - - override def isEmpty(): Boolean = count == 0L - - override def take(num: Int): Array[ConsumerRecord[Array[Byte], Array[Byte]]] = { - val nonEmptyPartitions = - this.partitions.map(_.asInstanceOf[KafkaSourceRDDPartition]).filter(_.offsetRange.size > 0) - - if (num < 1 || nonEmptyPartitions.isEmpty) { - return new Array[ConsumerRecord[Array[Byte], Array[Byte]]](0) - } - - // Determine in advance how many messages need to be taken from each partition - val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => - val remain = num - result.values.sum - if (remain > 0) { - val taken = Math.min(remain, part.offsetRange.size) - result + (part.index -> taken.toInt) - } else { - result - } - } - - val buf = new ArrayBuffer[ConsumerRecord[Array[Byte], Array[Byte]]] - val res = context.runJob( - this, - (tc: TaskContext, it: Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]) => - it.take(parts(tc.partitionId)).toArray, parts.keys.toArray - ) - res.foreach(buf ++= _) - buf.toArray - } - override def getPreferredLocations(split: Partition): Seq[String] = { val part = split.asInstanceOf[KafkaSourceRDDPartition] part.offsetRange.preferredLoc.map(Seq(_)).getOrElse(Seq.empty) @@ -124,8 +86,6 @@ private[kafka010] class KafkaSourceRDD( thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] - val topic = sourcePartition.offsetRange.topic - val kafkaPartition = sourcePartition.offsetRange.partition val consumer = 
KafkaDataConsumer.acquire( sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) @@ -138,6 +98,7 @@ private[kafka010] class KafkaSourceRDD( if (range.fromOffset == range.untilOffset) { logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " + s"skipping ${range.topic} ${range.partition}") + consumer.release() Iterator.empty } else { val underlying = new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { @@ -166,7 +127,7 @@ private[kafka010] class KafkaSourceRDD( } } // Release consumer, either by removing it or indicating we're no longer using it - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => underlying.closeIfNeeded() } underlying diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala similarity index 90% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala index ae5b5c52d514e..dc19312f79a22 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.types.StructType /** @@ -33,20 +33,20 @@ import org.apache.spark.sql.types.StructType case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory. * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. 
*/ -class KafkaStreamWriter( +class KafkaStreamingWriteSupport( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends StreamWriter with SupportsWriteInternalRow { + extends StreamingWriteSupport { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createInternalRowWriterFactory(): KafkaStreamWriterFactory = + override def createStreamingWriterFactory(): KafkaStreamWriterFactory = KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} @@ -63,11 +63,11 @@ class KafkaStreamWriter( */ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends DataWriterFactory[InternalRow] { + extends StreamingDataWriterFactory { - override def createDataWriter( + override def createWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index d90630a8adc93..041fac7717635 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -110,7 +110,7 @@ private[kafka010] abstract class KafkaRowWriter( case t => throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - "must be a StringType") + s"must be a ${StringType.catalogString}") } val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) .getOrElse(Literal(null, BinaryType)) @@ -118,7 +118,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.catalogString}") } val valueExpression = inputSchema .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( @@ -129,7 +129,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.catalogString}") } UnsafeProjection.create( Seq(topicExpression, Cast(keyExpression, BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..fc09938a43a8c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -57,7 +57,7 @@ private[kafka010] object KafkaWriter extends Logging { ).dataType match { case StringType => // good case _ => - throw new AnalysisException(s"Topic type must be a String") + throw new AnalysisException(s"Topic type must be a ${StringType.catalogString}") } schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( Literal(null, StringType) @@ -65,7 +65,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // 
good case _ => throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") } schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") @@ -73,7 +73,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala index 789bffa9da126..0b3355426df10 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala @@ -26,14 +26,13 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.sql.test.SharedSQLContext -class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester { +class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester with KafkaTest { type KP = KafkaProducer[Array[Byte], Array[Byte]] protected override def beforeEach(): Unit = { super.beforeEach() - val clear = PrivateMethod[Unit]('clear) - CachedKafkaProducer.invokePrivate(clear()) + CachedKafkaProducer.clear() } test("Should return the cached instance on calling getOrCreate with same params.") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index ddfc0c1a4be2d..3f6fcf6b2e52c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -40,12 +40,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { override val streamingTimeout = 30.seconds - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils( - withBrokerProps = Map("auto.create.topics.enable" -> "false")) - testUtils.setup() - } + override val brokerProps = Map("auto.create.topics.enable" -> "false") override def afterAll(): Unit = { if (testUtils != null) { @@ -314,7 +309,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { /* key field wrong type */ @@ -330,7 +325,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index aab8ec42189fb..af510219a6f6f 100644 --- 
a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,12 +17,159 @@ package org.apache.spark.sql.kafka010 +import org.apache.kafka.clients.producer.ProducerRecord + import org.apache.spark.sql.Dataset -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. -class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest { + import testImplicits._ + + test("read Kafka transactional messages: read_committed") { + val table = "kafka_continuous_source_test" + withTable(table) { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_committed") + .option("startingOffsets", "earliest") + .option("subscribe", topic) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt) + + val q = df + .writeStream + .format("memory") + .queryName(table) + .trigger(ContinuousTrigger(100)) + .start() + try { + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // Should not read any messages before they are committed + assert(spark.table(table).isEmpty) + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all committed messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // Should not read aborted messages + checkAnswer(spark.table(table), (1 to 5).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should skip aborted messages and read new committed ones. 
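For reference, the visibility rules asserted in this continuous-mode test come straight from the Kafka transactional client contract: records sent inside a transaction reach read_committed consumers only after commitTransaction() returns, and aborted records are never delivered to them. Below is a minimal, self-contained sketch of that contract using the plain Kafka 2.0 client API; the broker address, topic name, and group id are placeholders rather than values from this patch, and a single poll is used only for brevity.

import java.time.Duration
import java.util.{Collections, Properties, UUID}

import scala.collection.JavaConverters._

import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}

object TransactionalVisibilitySketch {
  def main(args: Array[String]): Unit = {
    val brokers = "localhost:9092"   // placeholder broker address
    val topic = "demo-topic"         // placeholder topic

    val producerProps = new Properties()
    producerProps.put("bootstrap.servers", brokers)
    producerProps.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer")
    producerProps.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer")
    producerProps.put("acks", "all") // transactions require acks=all
    // A transactional.id must be set before initTransactions() can be called.
    producerProps.put("transactional.id", UUID.randomUUID().toString)

    val producer = new KafkaProducer[String, String](producerProps)
    producer.initTransactions()

    producer.beginTransaction()
    (1 to 5).foreach(i => producer.send(new ProducerRecord[String, String](topic, i.toString)).get())
    producer.commitTransaction()     // values 1..5 become visible to read_committed readers

    producer.beginTransaction()
    (6 to 10).foreach(i => producer.send(new ProducerRecord[String, String](topic, i.toString)).get())
    producer.abortTransaction()      // values 6..10 are never delivered to read_committed readers
    producer.close()

    val consumerProps = new Properties()
    consumerProps.put("bootstrap.servers", brokers)
    consumerProps.put("group.id", "demo-group")          // placeholder group id
    consumerProps.put("auto.offset.reset", "earliest")
    consumerProps.put("isolation.level", "read_committed") // default is read_uncommitted
    consumerProps.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer")
    consumerProps.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer")

    val consumer = new KafkaConsumer[String, String](consumerProps)
    consumer.subscribe(Collections.singletonList(topic))
    // Single poll for brevity; a real reader would poll in a loop until caught up.
    consumer.poll(Duration.ofSeconds(5)).asScala.foreach { r =>
      println(s"offset=${r.offset} value=${r.value}") // only values 1..5 appear
    }
    consumer.close()
  }
}

The kafka.isolation.level option used in the test simply forwards isolation.level to the consumers the Kafka source creates, so the same visibility rules apply to the rows Spark returns.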
+ checkAnswer(spark.table(table), ((1 to 5) ++ (11 to 15)).toDF) + } + } finally { + q.stop() + } + } + } + } + + test("read Kafka transactional messages: read_uncommitted") { + val table = "kafka_continuous_source_test" + withTable(table) { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_uncommitted") + .option("startingOffsets", "earliest") + .option("subscribe", topic) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt) + + val q = df + .writeStream + .format("memory") + .queryName(table) + .trigger(ContinuousTrigger(100)) + .start() + try { + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + eventually(timeout(streamingTimeout)) { + // Should read uncommitted messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all committed messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read aborted messages + checkAnswer(spark.table(table), (1 to 10).toDF) + } + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + eventually(timeout(streamingTimeout)) { + // Should read all messages including committed, aborted and uncommitted messages + checkAnswer(spark.table(table), (1 to 15).toDF) + } + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all messages including committed and aborted messages + checkAnswer(spark.table(table), (1 to 15).toDF) + } + } finally { + q.stop() + } + } + } + } +} class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { import testImplicits._ @@ -42,6 +189,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") .option("subscribePattern", s"$topicPrefix-.*") .option("failOnDataLoss", "false") @@ -59,11 +207,13 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { testUtils.createTopic(topic2, partitions = 5) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r - }.exists { r => + query.lastExecution.executedPlan.collectFirst { + case scan: DataSourceV2ScanExec + if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] + }.exists { config => // Ensure the new topic is present and the old topic is gone. 
- r.knownPartitions.exists(_.topic == topic2) + config.knownPartitions.exists(_.topic == topic2) }, s"query never reconfigured to new topic $topic2") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa1468a3943c8..fa6bdc20bd4f9 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -46,8 +46,10 @@ trait KafkaContinuousTest extends KafkaSourceTest { testUtils.addPartitions(topic, newCount) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r + query.lastExecution.executedPlan.collectFirst { + case scan: DataSourceV2ScanExec + if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala new file mode 100644 index 0000000000000..39c4e3fda1a4b --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable +import scala.util.Random + +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter} +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +/** + * This is a basic test trait which will set up a Kafka cluster that keeps only several records in + * a topic and ages out records very quickly. This is a helper trait to test + * "failonDataLoss=false" case with missing offsets. + * + * Note: there is a hard-code 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) to clean up + * records. Hence each class extending this trait needs to wait at least 30 seconds (or even longer + * when running on a slow Jenkins machine) before records start to be removed. To make sure a test + * does see missing offsets, you can check the earliest offset in `eventually` and make sure it's + * not 0 rather than sleeping a hard-code duration. + */ +trait KafkaMissingOffsetsTest extends SharedSQLContext { + + protected var testUtils: KafkaTestUtils = _ + + override def createSparkSession(): TestSparkSession = { + // Set maxRetries to 3 to handle NPE from `poll` when deleting a topic + new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context", sparkConf)) + } + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils { + override def brokerConfiguration: Properties = { + val props = super.brokerConfiguration + // Try to make Kafka clean up messages as fast as possible. However, there is a hard-code + // 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) so this test should run at + // least 30 seconds. + props.put("log.cleaner.backoff.ms", "100") + // The size of RecordBatch V2 increases to support transactional write. + props.put("log.segment.bytes", "70") + props.put("log.retention.bytes", "40") + props.put("log.retention.check.interval.ms", "100") + props.put("delete.retention.ms", "10") + props.put("log.flush.scheduler.interval.ms", "10") + props + } + } + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } +} + +class KafkaDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTest { + + import testImplicits._ + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" + + /** + * @param testStreamingQuery whether to test a streaming query or a batch query. + * @param writeToTable the function to write the specified [[DataFrame]] to the given table. 
+ */ + private def verifyMissingOffsetsDontCauseDuplicatedRecords( + testStreamingQuery: Boolean)(writeToTable: (DataFrame, String) => Unit): Unit = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (0 until 50).map(_.toString).toArray) + + eventually(timeout(60.seconds)) { + assert( + testUtils.getEarliestOffsets(Set(topic)).head._2 > 0, + "Kafka didn't delete records after 1 minute") + } + + val table = "DontFailOnDataLoss" + withTable(table) { + val kafkaOptions = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "kafka.metadata.max.age.ms" -> "1", + "subscribe" -> topic, + "startingOffsets" -> s"""{"$topic":{"0":0}}""", + "failOnDataLoss" -> "false", + "kafkaConsumer.pollTimeoutMs" -> "1000") + val df = + if (testStreamingQuery) { + val reader = spark.readStream.format("kafka") + kafkaOptions.foreach(kv => reader.option(kv._1, kv._2)) + reader.load() + } else { + val reader = spark.read.format("kafka") + kafkaOptions.foreach(kv => reader.option(kv._1, kv._2)) + reader.load() + } + writeToTable(df.selectExpr("CAST(value AS STRING)"), table) + val result = spark.table(table).as[String].collect().toList + assert(result.distinct.size === result.size, s"$result contains duplicated records") + // Make sure Kafka did remove some records so that this test is valid. + assert(result.size > 0 && result.size < 50) + } + } + + test("failOnDataLoss=false should not return duplicated records: v1") { + withSQLConf( + "spark.sql.streaming.disabledV2MicroBatchReaders" -> + classOf[KafkaSourceProvider].getCanonicalName) { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream.format("memory").queryName(table).start() + try { + query.processAllAvailable() + } finally { + query.stop() + } + } + } + } + + test("failOnDataLoss=false should not return duplicated records: v2") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream.format("memory").queryName(table).start() + try { + query.processAllAvailable() + } finally { + query.stop() + } + } + } + + test("failOnDataLoss=false should not return duplicated records: continuous processing") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream + .format("memory") + .queryName(table) + .trigger(Trigger.Continuous(100)) + .start() + try { + // `processAllAvailable` doesn't work for continuous processing, so just wait until the last + // record appears in the table. + eventually(timeout(streamingTimeout)) { + assert(spark.table(table).as[String].collect().contains("49")) + } + } finally { + query.stop() + } + } + } + + test("failOnDataLoss=false should not return duplicated records: batch") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = false) { (df, table) => + df.write.saveAsTable(table) + } + } +} + +class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTest { + + import testImplicits._ + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" + + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = true + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. 
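The missing-offsets check above leans on testUtils.getEarliestOffsets to prove that retention really deleted the head of the log before the query ran. With the plain consumer API the equivalent probe is a beginningOffsets call; the following is a minimal sketch under that assumption, with the broker address and topic passed in as placeholders.

import java.util.Properties

import scala.collection.JavaConverters._

import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.common.TopicPartition

object EarliestOffsetProbe {
  // Returns the earliest retained offset per partition of `topic`; a value greater than 0
  // means retention has already deleted the head of the log, which is exactly the
  // precondition the missing-offsets suites wait for.
  def earliestOffsets(brokers: String, topic: String): Map[TopicPartition, Long] = {
    val props = new Properties()
    props.put("bootstrap.servers", brokers)
    props.put("key.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer")
    props.put("value.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer")
    val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](props)
    try {
      val partitions = consumer.partitionsFor(topic).asScala
        .map(info => new TopicPartition(info.topic, info.partition))
      consumer.beginningOffsets(partitions.asJava).asScala
        .map { case (tp, offset) => tp -> offset.longValue }
        .toMap
    } finally {
      consumer.close()
    }
  }
}

Wrapping a call like this in an eventually { assert(...exists(_ > 0)) } block is the shape of the wait performed above before the actual query is started.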
+ Thread.sleep(Random.nextInt(500)) + } + + override def close(errorOrNull: Throwable): Unit = {} + }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) + + val testTime = 1.minutes + val startTime = System.currentTimeMillis() + // Track the current existing topics + val topics = mutable.ArrayBuffer[String]() + // Track topics that have been deleted + val deletedTopics = mutable.Set[String]() + while (System.currentTimeMillis() - testTime.toMillis < startTime) { + Random.nextInt(10) match { + case 0 => // Create a new topic + val topic = newTopic() + topics += topic + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. + // Hence, always overwrite to handle this race condition. + testUtils.createTopic(topic, partitions = 1, overwrite = true) + logInfo(s"Create topic $topic") + case 1 if topics.nonEmpty => // Delete an existing topic + val topic = topics.remove(Random.nextInt(topics.size)) + testUtils.deleteTopic(topic) + logInfo(s"Delete topic $topic") + deletedTopics += topic + case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted. + val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size)) + deletedTopics -= topic + topics += topic + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. + // Hence, always overwrite to handle this race condition. 
+ testUtils.createTopic(topic, partitions = 1, overwrite = true) + logInfo(s"Create topic $topic") + case 3 => + Thread.sleep(1000) + case _ => // Push random messages + for (topic <- topics) { + val size = Random.nextInt(10) + for (_ <- 0 until size) { + testUtils.sendMessages(topic, Array(Random.nextInt(10).toString)) + } + } + } + // `failOnDataLoss` is `false`, we should not fail the query + if (query.exception.nonEmpty) { + throw query.exception.get + } + } + + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + if (query.exception.nonEmpty) { + throw query.exception.get + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e017fd9b84d21..eb66ccac744a3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,36 +20,33 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Optional, Properties} +import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.io.Source import scala.util.Random -import org.apache.kafka.clients.producer.RecordMetadata +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord, RecordMetadata} import org.apache.kafka.common.TopicPartition +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkContext -import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update +import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.test.SharedSQLContext -abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { +abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with KafkaTest { protected var testUtils: KafkaTestUtils = _ @@ -117,14 +114,16 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources = { + val sources: Seq[BaseStreamingSource] = { query.get.logicalPlan.collect { case StreamingExecutionRelation(source: KafkaSource, _) => source - case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source + 
case StreamingExecutionRelation(source: KafkaMicroBatchReadSupport, _) => source } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader + case r: StreamingDataSourceV2Relation + if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + r.readSupport.asInstanceOf[KafkaContinuousReadSupport] } }) }.distinct @@ -160,6 +159,19 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { s"AddKafkaData(topics = $topics, data = $data, message = $message)" } + object WithOffsetSync { + def apply(topic: String)(func: () => Unit): StreamAction = { + Execute("Run Kafka Producer")(_ => { + func() + // This is a hack for the race condition that the committed message may be not visible to + // consumer for a short time. + // Looks like after the following call returns, the consumer can always read the committed + // messages. + testUtils.getLatestOffsets(Set(topic)) + }) + } + } + private val topicId = new AtomicInteger(0) protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } @@ -290,6 +302,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") .option("subscribePattern", s"$topicPrefix-.*") .option("failOnDataLoss", "false") @@ -467,6 +480,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") .option("subscribe", topic) // If a topic is deleted and we try to poll data starting from offset 0, // the Kafka consumer will just block until timeout and return an empty result. @@ -563,7 +577,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } - test("ensure stream-stream self-join generates only one offset in offset log") { + test("ensure stream-stream self-join generates only one offset in log and correct metrics") { val topic = newTopic() testUtils.createTopic(topic, partitions = 2) require(testUtils.getLatestOffsets(Set(topic)).size === 2) @@ -587,9 +601,254 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AddKafkaData(Set(topic), 1, 2), CheckAnswer((1, 1, 1), (2, 2, 2)), AddKafkaData(Set(topic), 6, 3), - CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), + AssertOnQuery { q => + assert(q.availableOffsets.iterator.size == 1) + assert(q.recentProgress.map(_.numInputRows).sum == 4) + true + } ) } + + test("read Kafka transactional messages: read_committed") { + // This test will cover the following cases: + // 1. the whole batch contains no data messages + // 2. the first offset in a batch is not a committed data message + // 3. the last offset in a batch is not a committed data message + // 4. 
there is a gap in the middle of a batch + + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.isolation.level", "read_committed") + .option("maxOffsetsPerTrigger", 3) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + // Set a short timeout to make the test fast. When a batch doesn't contain any visible data + // messages, "poll" will wait until timeout. + .option("kafkaConsumer.pollTimeoutMs", 5000) + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + // Wait until the manual clock is waiting on further instructions to move forward. Then we can + // ensure all batches we are waiting for have been processed. + val waitUntilBatchProcessed = Execute { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + } + + // The message values are the same as their offsets to make the test easy to follow + testUtils.withTranscationalProducer { producer => + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + CheckAnswer(), + WithOffsetSync(topic) { () => + // Send 5 messages. They should be visible only after being committed. + producer.beginTransaction() + (0 to 4).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + // Should not see any uncommitted messages + CheckNewAnswer(), + WithOffsetSync(topic) { () => + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(0, 1, 2), // offset 0, 1, 2 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(3, 4), // offset: 3, 4, 5* [* means it's not a committed data message] + WithOffsetSync(topic) { () => + // Send 5 messages and abort the transaction. They should not be read. + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(), // offset: 6*, 7*, 8* + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(), // offset: 9*, 10*, 11* + WithOffsetSync(topic) { () => + // Send 5 messages again. The consumer should skip the above aborted messages and read + // them. 
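The starred offsets in the comments above are positions taken up by aborted data records or by the single control record Kafka writes per commit or abort in each partition; they advance the offset but yield no row under read_committed. Because maxOffsetsPerTrigger budgets offsets rather than visible rows, a trigger can legitimately return fewer than three rows, or none at all. The following is a small self-contained sketch of that bookkeeping, with the offset layout copied from this test's comments.

object ReadCommittedOffsetBudget {
  // Value i was produced at offset i; the starred offsets (5, 6-11, 17, 19, 21, 24) hold
  // aborted records or commit/abort markers and therefore yield no rows.
  val visibleRows: Map[Long, Int] = Map(
    0L -> 0, 1L -> 1, 2L -> 2, 3L -> 3, 4L -> 4,
    12L -> 12, 13L -> 13, 14L -> 14, 15L -> 15, 16L -> 16,
    18L -> 18, 20L -> 20, 22L -> 22, 23L -> 23)

  // Rows a read_committed scan returns for one trigger covering offsets [from, until).
  def rowsFor(from: Long, until: Long): Seq[Int] = (from until until).flatMap(visibleRows.get)

  def main(args: Array[String]): Unit = {
    // Advance 3 offsets per trigger, exactly like maxOffsetsPerTrigger = 3 above.
    (0L until 25L by 3L).zipWithIndex.foreach { case (start, i) =>
      println(s"trigger $i -> ${rowsFor(start, math.min(start + 3, 25L))}")
    }
  }
}

Running the sketch prints the same sequence of visible values the CheckNewAnswer steps expect once the first transaction commits, including the two empty triggers that cover only aborted data and markers.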
+ producer.beginTransaction() + (12 to 16).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(12, 13, 14), // offset: 12, 13, 14 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(15, 16), // offset: 15, 16, 17* + WithOffsetSync(topic) { () => + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "18")).get() + producer.commitTransaction() + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "20")).get() + producer.commitTransaction() + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "22")).get() + producer.send(new ProducerRecord[String, String](topic, "23")).get() + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(18, 20), // offset: 18, 19*, 20 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(22, 23), // offset: 21*, 22, 23 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer() // offset: 24* + ) + } + } + + test("read Kafka transactional messages: read_uncommitted") { + // This test will cover the following cases: + // 1. the whole batch contains no data messages + // 2. the first offset in a batch is not a committed data message + // 3. the last offset in a batch is not a committed data message + // 4. there is a gap in the middle of a batch + + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.isolation.level", "read_uncommitted") + .option("maxOffsetsPerTrigger", 3) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + // Set a short timeout to make the test fast. When a batch doesn't contain any visible data + // messages, "poll" will wait until timeout. + .option("kafkaConsumer.pollTimeoutMs", 5000) + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + // Wait until the manual clock is waiting on further instructions to move forward. Then we can + // ensure all batches we are waiting for have been processed. + val waitUntilBatchProcessed = Execute { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + } + + // The message values are the same as their offsets to make the test easy to follow + testUtils.withTranscationalProducer { producer => + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + CheckNewAnswer(), + WithOffsetSync(topic) { () => + // Send 5 messages. They should be visible only after being committed. 
+ producer.beginTransaction() + (0 to 4).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(0, 1, 2), // offset 0, 1, 2 + WithOffsetSync(topic) { () => + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(3, 4), // offset: 3, 4, 5* [* means it's not a committed data message] + WithOffsetSync(topic) { () => + // Send 5 messages and abort the transaction. They should not be read. + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(6, 7, 8), // offset: 6, 7, 8 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(9, 10), // offset: 9, 10, 11* + WithOffsetSync(topic) { () => + // Send 5 messages again. The consumer should skip the above aborted messages and read + // them. + producer.beginTransaction() + (12 to 16).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(12, 13, 14), // offset: 12, 13, 14 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(15, 16), // offset: 15, 16, 17* + WithOffsetSync(topic) { () => + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "18")).get() + producer.commitTransaction() + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "20")).get() + producer.commitTransaction() + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "22")).get() + producer.send(new ProducerRecord[String, String](topic, "23")).get() + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(18, 20), // offset: 18, 19*, 20 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(22, 23), // offset: 21*, 22, 23 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer() // offset: 24* + ) + } + } } @@ -642,7 +901,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { makeSureGetOffsetCalled, AssertOnQuery { query => query.logicalPlan.collect { - case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true + case StreamingExecutionRelation(_: KafkaMicroBatchReadSupport, _) => true }.nonEmpty } ) @@ -667,17 +926,16 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "kafka.bootstrap.servers" -> testUtils.brokerAddress, "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} - val reader = provider.createMicroBatchReader( - Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) - reader.setOffsetRange( - Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), - Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) - ) - val factories = reader.createUnsafeRowReaderFactories().asScala - .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) - withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { - assert(factories.size == numPartitionsGenerated) - factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } + val readSupport = provider.createMicroBatchReadSupport( + dir.getAbsolutePath, new 
DataSourceOptions(options.asJava)) + val config = readSupport.newScanConfigBuilder( + KafkaSourceOffset(Map(tp -> 0L)), + KafkaSourceOffset(Map(tp -> 100L))).build() + val inputPartitions = readSupport.planInputPartitions(config) + .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) + withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") { + assert(inputPartitions.size == numPartitionsGenerated) + inputPartitions.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } } } } @@ -694,6 +952,41 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } } + test("custom lag metrics") { + import testImplicits._ + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + testUtils.sendMessages(topic, (1 to 100).map(_.toString).toArray) + require(testUtils.getLatestOffsets(Set(topic)).size === 2) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .option("maxOffsetsPerTrigger", 10) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + implicit val formats = DefaultFormats + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = OneTimeTrigger), + AssertOnQuery { query => + query.awaitTermination() + val source = query.lastProgress.sources(0) + // masOffsetsPerTrigger is 10, and there are two partitions containing 50 events each + // so 5 events should be processed from each partition and a lag of 45 events + val custom = parse(source.customMetrics) + .extract[Map[String, Map[String, Map[String, Long]]]] + custom("lag")(topic)("0") == 45 && custom("lag")(topic)("1") == 45 + } + ) + } + } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { @@ -928,7 +1221,8 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { makeSureGetOffsetCalled, Execute { q => // wait to reach the last offset in every partition - q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + q.awaitOffset( + 0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)), streamingTimeout.toMillis) }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, @@ -1098,6 +1392,7 @@ class KafkaSourceStressSuite extends KafkaSourceTest { .option("kafka.metadata.max.age.ms", "1") .option("subscribePattern", "stress.*") .option("failOnDataLoss", "false") + .option("kafka.default.api.timeout.ms", "3000") .load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] @@ -1143,132 +1438,3 @@ class KafkaSourceStressSuite extends KafkaSourceTest { iterations = 50) } } - -class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with SharedSQLContext { - - import testImplicits._ - - private var testUtils: KafkaTestUtils = _ - - private val topicId = new AtomicInteger(0) - - private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" - - override def createSparkSession(): TestSparkSession = { - // Set maxRetries to 3 to handle NPE from `poll` when deleting a topic - new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context", sparkConf)) - } - - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils { - override def brokerConfiguration: Properties = { - val props = super.brokerConfiguration - // Try to make Kafka clean up messages as fast as 
possible. However, there is a hard-code - // 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) so this test should run at - // least 30 seconds. - props.put("log.cleaner.backoff.ms", "100") - props.put("log.segment.bytes", "40") - props.put("log.retention.bytes", "40") - props.put("log.retention.check.interval.ms", "100") - props.put("delete.retention.ms", "10") - props.put("log.flush.scheduler.interval.ms", "10") - props - } - } - testUtils.setup() - } - - override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null - super.afterAll() - } - } - - protected def startStream(ds: Dataset[Int]) = { - ds.writeStream.foreach(new ForeachWriter[Int] { - - override def open(partitionId: Long, version: Long): Boolean = { - true - } - - override def process(value: Int): Unit = { - // Slow down the processing speed so that messages may be aged out. - Thread.sleep(Random.nextInt(500)) - } - - override def close(errorOrNull: Throwable): Unit = { - } - }).start() - } - - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = startStream(kafka.map(kv => kv._2.toInt)) - - val testTime = 1.minutes - val startTime = System.currentTimeMillis() - // Track the current existing topics - val topics = mutable.ArrayBuffer[String]() - // Track topics that have been deleted - val deletedTopics = mutable.Set[String]() - while (System.currentTimeMillis() - testTime.toMillis < startTime) { - Random.nextInt(10) match { - case 0 => // Create a new topic - val topic = newTopic() - topics += topic - // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small - // chance that a topic will be recreated after deletion due to the asynchronous update. - // Hence, always overwrite to handle this race condition. - testUtils.createTopic(topic, partitions = 1, overwrite = true) - logInfo(s"Create topic $topic") - case 1 if topics.nonEmpty => // Delete an existing topic - val topic = topics.remove(Random.nextInt(topics.size)) - testUtils.deleteTopic(topic) - logInfo(s"Delete topic $topic") - deletedTopics += topic - case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted. - val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size)) - deletedTopics -= topic - topics += topic - // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small - // chance that a topic will be recreated after deletion due to the asynchronous update. - // Hence, always overwrite to handle this race condition. 
- testUtils.createTopic(topic, partitions = 1, overwrite = true) - logInfo(s"Create topic $topic") - case 3 => - Thread.sleep(1000) - case _ => // Push random messages - for (topic <- topics) { - val size = Random.nextInt(10) - for (_ <- 0 until size) { - testUtils.sendMessages(topic, Array(Random.nextInt(10).toString)) - } - } - } - // `failOnDataLoss` is `false`, we should not fail the query - if (query.exception.nonEmpty) { - throw query.exception.get - } - } - - query.stop() - // `failOnDataLoss` is `false`, we should not fail the query - if (query.exception.nonEmpty) { - throw query.exception.get - } - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 91893df4ec32f..93dba18446280 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.kafka010 import java.util.Locale import java.util.concurrent.atomic.AtomicInteger +import org.apache.kafka.clients.producer.ProducerRecord import org.apache.kafka.common.TopicPartition -import org.scalatest.BeforeAndAfter import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { +class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest { import testImplicits._ @@ -235,4 +235,96 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon testBadOptions("subscribe" -> "")("no topics to subscribe") testBadOptions("subscribePattern" -> "")("pattern to subscribe is empty") } + + test("read Kafka transactional messages: read_committed") { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_committed") + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // Should not read any messages before they are committed + assert(df.isEmpty) + + producer.commitTransaction() + + // Should read all committed messages + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // Should not read aborted messages + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + // Should skip aborted messages and read new committed ones. 
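Outside the test scaffolding, the batch path exercised here is an ordinary spark.read against the Kafka relation; the test relies on the fact that each action re-resolves offsets and re-scans the topic, which is why the same DataFrame returns more rows after each later commit. A minimal usage sketch follows, with a placeholder broker address and topic name; startingOffsets and endingOffsets are left at their batch defaults of earliest and latest.

import org.apache.spark.sql.SparkSession

object ReadCommittedBatchSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("read-committed-batch").getOrCreate()

    // Batch read of everything committed to the topic at the time each action runs.
    val df = spark.read
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092") // placeholder
      .option("subscribe", "demo-topic")                    // placeholder
      .option("kafka.isolation.level", "read_committed")    // passed through to the consumer
      .load()
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "topic", "partition", "offset")

    df.show(truncate = false)
    spark.stop()
  }
}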
+ checkAnswer(df, ((1 to 5) ++ (11 to 15)).map(_.toString).toDF) + } + } + + test("read Kafka transactional messages: read_uncommitted") { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_uncommitted") + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // "read_uncommitted" should see all messages including uncommitted ones + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.commitTransaction() + + // Should read all committed messages + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // "read_uncommitted" should see all messages including uncommitted or aborted ones + checkAnswer(df, (1 to 10).map(_.toString).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + // Should read all messages + checkAnswer(df, (1 to 15).map(_.toString).toDF) + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 7079ac6453ffc..a2213e024bd98 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, DataType} -class KafkaSinkSuite extends StreamTest with SharedSQLContext { +class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest { import testImplicits._ protected var testUtils: KafkaTestUtils = _ @@ -303,7 +303,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { ex = intercept[StreamingQueryException] { @@ -318,7 +318,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala new file mode 100644 index 0000000000000..19acda95c707c --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite + +/** A trait to clean cached Kafka producers in `afterAll` */ +trait KafkaTest extends BeforeAndAfterAll { + self: SparkFunSuite => + + override def afterAll(): Unit = { + super.afterAll() + CachedKafkaProducer.clear() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 75245943c4936..7b742a3ea6741 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io.{File, IOException} import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.{Map => JMap, Properties} +import java.util.{Map => JMap, Properties, UUID} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -29,20 +29,23 @@ import scala.util.Random import kafka.admin.AdminUtils import kafka.api.Request -import kafka.common.TopicAndPartition -import kafka.server.{KafkaConfig, KafkaServer, OffsetCheckpoint} +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.server.checkpoints.OffsetCheckpointFile import kafka.utils.ZkUtils +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.{AdminClient, CreatePartitionsOptions, NewPartitions} import org.apache.kafka.clients.consumer.KafkaConsumer import org.apache.kafka.clients.producer._ import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This is a helper class for Kafka test suites. 
This has the functionality to set up @@ -53,17 +56,18 @@ import org.apache.spark.util.Utils class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends Logging { // Zookeeper related configurations - private val zkHost = "localhost" + private val zkHost = "127.0.0.1" private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 - private val zkSessionTimeout = 6000 + private val zkSessionTimeout = 10000 private var zookeeper: EmbeddedZookeeper = _ private var zkUtils: ZkUtils = _ + private var adminClient: AdminClient = null // Kafka broker related configurations - private val brokerHost = "localhost" + private val brokerHost = "127.0.0.1" private var brokerPort = 0 private var brokerConf: KafkaConfig = _ @@ -76,6 +80,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L // Flag to test whether the system is correctly started private var zkReady = false private var brokerReady = false + private var leakDetector: AnyRef = null def zkAddress: String = { assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") @@ -113,21 +118,37 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) server = new KafkaServer(brokerConf) server.startup() - brokerPort = server.boundPort() + brokerPort = server.boundPort(new ListenerName("PLAINTEXT")) (server, brokerPort) }, new SparkConf(), "KafkaBroker") brokerReady = true + val props = new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, s"$brokerHost:$brokerPort") + adminClient = AdminClient.create(props) } /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ def setup(): Unit = { + // Set up a KafkaTestUtils leak detector so that we can see where the leak KafkaTestUtils is + // created. 
+ val exception = new SparkException("It was created at: ") + leakDetector = ShutdownHookManager.addShutdownHook { () => + logError("Found a leak KafkaTestUtils.", exception) + } + setupEmbeddedZookeeper() setupEmbeddedKafkaServer() + eventually(timeout(60.seconds)) { + assert(zkUtils.getAllBrokersInCluster().nonEmpty, "Broker was not up in 60 seconds") + } } /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { + if (leakDetector != null) { + ShutdownHookManager.removeShutdownHook(leakDetector) + } brokerReady = false zkReady = false @@ -136,6 +157,10 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L producer = null } + if (adminClient != null) { + adminClient.close() + } + if (server != null) { server.shutdown() server.awaitShutdown() @@ -203,7 +228,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L /** Add new partitions to a Kafka topic */ def addPartitions(topic: String, partitions: Int): Unit = { - AdminUtils.addPartitions(zkUtils, topic, partitions) + adminClient.createPartitions( + Map(topic -> NewPartitions.increaseTo(partitions)).asJava, + new CreatePartitionsOptions) // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) @@ -287,15 +314,23 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L protected def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") - props.put("host.name", "localhost") - props.put("advertised.host.name", "localhost") + props.put("host.name", "127.0.0.1") + props.put("advertised.host.name", "127.0.0.1") props.put("port", brokerPort.toString) props.put("log.dir", Utils.createTempDir().getAbsolutePath) props.put("zookeeper.connect", zkAddress) + props.put("zookeeper.connection.timeout.ms", "60000") props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") + props.put("group.initial.rebalance.delay.ms", "10") + + // Change the following settings as we have only 1 broker props.put("offsets.topic.num.partitions", "1") + props.put("offsets.topic.replication.factor", "1") + props.put("transaction.state.log.replication.factor", "1") + props.put("transaction.state.log.min.isr", "1") + // Can not use properties.putAll(propsMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 withBrokerProps.foreach { case (k, v) => props.put(k, v) } @@ -312,6 +347,19 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props } + /** Call `f` with a `KafkaProducer` that has initialized transactions. 
*/ + def withTranscationalProducer(f: KafkaProducer[String, String] => Unit): Unit = { + val props = producerConfiguration + props.put("transactional.id", UUID.randomUUID().toString) + val producer = new KafkaProducer[String, String](props) + try { + producer.initTransactions() + f(producer) + } finally { + producer.close() + } + } + private def consumerConfiguration: Properties = { val props = new Properties() props.put("bootstrap.servers", brokerAddress) @@ -327,7 +375,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L topic: String, numPartitions: Int, servers: Seq[KafkaServer]): Unit = { - val topicAndPartitions = (0 until numPartitions).map(TopicAndPartition(topic, _)) + val topicAndPartitions = (0 until numPartitions).map(new TopicPartition(topic, _)) import ZkUtils._ // wait until admin path for delete topic is deleted, signaling completion of topic deletion @@ -337,7 +385,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L assert(!zkUtils.pathExists(getTopicPath(topic)), s"${getTopicPath(topic)} still exists") // ensure that the topic-partition has been deleted from all brokers' replica managers assert(servers.forall(server => topicAndPartitions.forall(tp => - server.replicaManager.getPartition(tp.topic, tp.partition) == None)), + server.replicaManager.getPartition(tp) == None)), s"topic $topic still exists in the replica manager") // ensure that logs from all replicas are deleted if delete topic is marked successful assert(servers.forall(server => topicAndPartitions.forall(tp => @@ -345,8 +393,8 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L s"topic $topic still exists in log mananger") // ensure that topic is removed from all cleaner offsets assert(servers.forall(server => topicAndPartitions.forall { tp => - val checkpoints = server.getLogManager().logDirs.map { logDir => - new OffsetCheckpoint(new File(logDir, "cleaner-offset-checkpoint")).read() + val checkpoints = server.getLogManager().liveLogDirs.map { logDir => + new OffsetCheckpointFile(new File(logDir, "cleaner-offset-checkpoint")).read() } checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) }), s"checkpoint for topic $topic still exists") @@ -379,11 +427,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => - val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr - zkUtils.getLeaderForPartition(topic, partition).isDefined && - Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && - leaderAndInSyncReplicas.isr.nonEmpty + Request.isValidBrokerId(partitionState.basePartitionState.leader) && + !partitionState.basePartitionState.replicas.isEmpty case _ => false diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 3b124b2a69d50..a97fd35bfbb73 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -28,7 +28,8 @@ spark-streaming-kafka-0-10_2.11 streaming-kafka-0-10 - 0.10.0.1 + + 2.0.0 jar Spark Integration for Kafka 0.10 @@ -58,6 +59,20 @@ kafka_${scala.binary.version} ${kafka.version} test + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + net.sf.jopt-simple @@ -93,13 +108,4 @@ 
target/scala-${scala.binary.version}/test-classes - - - scala-2.12 - - 0.10.1.1 - - - - diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala deleted file mode 100644 index aeb8c1dc342b3..0000000000000 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.kafka010 - -import java.{ util => ju } - -import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } -import org.apache.kafka.common.{ KafkaException, TopicPartition } - -import org.apache.spark.internal.Logging - -/** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. - */ -private[kafka010] -class CachedKafkaConsumer[K, V] private( - val groupId: String, - val topic: String, - val partition: Int, - val kafkaParams: ju.Map[String, Object]) extends Logging { - - require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), - "groupId used for cache key must match the groupId in kafkaParams") - - val topicPartition = new TopicPartition(topic, partition) - - protected val consumer = { - val c = new KafkaConsumer[K, V](kafkaParams) - val tps = new ju.ArrayList[TopicPartition]() - tps.add(topicPartition) - c.assign(tps) - c - } - - // TODO if the buffer was kept around as a random-access structure, - // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() - protected var nextOffset = -2L - - def close(): Unit = consumer.close() - - /** - * Get the record for the given offset, waiting up to timeout ms if IO is necessary. - * Sequential forward access will use buffers, but random access will be horribly inefficient. 
- */ - def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { - logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset") - if (offset != nextOffset) { - logInfo(s"Initial fetch for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - - if (!buffer.hasNext()) { poll(timeout) } - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - var record = buffer.next() - - if (record.offset != offset) { - logInfo(s"Buffer miss for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - record = buffer.next() - require(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + - s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + - "spark.streaming.kafka.allowNonConsecutiveOffsets" - ) - } - - nextOffset = offset + 1 - record - } - - /** - * Start a batch on a compacted topic - */ - def compactedStart(offset: Long, timeout: Long): Unit = { - logDebug(s"compacted start $groupId $topic $partition starting $offset") - // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics - if (offset != nextOffset) { - logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - } - - /** - * Get the next record in the batch from a compacted topic. - * Assumes compactedStart has been called first, and ignores gaps. - */ - def compactedNext(timeout: Long): ConsumerRecord[K, V] = { - if (!buffer.hasNext()) { - poll(timeout) - } - require(buffer.hasNext(), - s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") - val record = buffer.next() - nextOffset = record.offset + 1 - record - } - - /** - * Rewind to previous record in the batch from a compacted topic. - * @throws NoSuchElementException if no previous element - */ - def compactedPrevious(): ConsumerRecord[K, V] = { - buffer.previous() - } - - private def seek(offset: Long): Unit = { - logDebug(s"Seeking to $topicPartition $offset") - consumer.seek(topicPartition, offset) - } - - private def poll(timeout: Long): Unit = { - val p = consumer.poll(timeout) - val r = p.records(topicPartition) - logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.listIterator - } - -} - -private[kafka010] -object CachedKafkaConsumer extends Logging { - - private case class CacheKey(groupId: String, topic: String, partition: Int) - - // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap - private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null - - /** Must be called before get, once per JVM, to configure the cache. 
Further calls are ignored */ - def init( - initialCapacity: Int, - maxCapacity: Int, - loadFactor: Float): Unit = CachedKafkaConsumer.synchronized { - if (null == cache) { - logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") - cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]]( - initialCapacity, loadFactor, true) { - override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = { - if (this.size > maxCapacity) { - try { - entry.getValue.consumer.close() - } catch { - case x: KafkaException => - logError("Error closing oldest Kafka consumer", x) - } - true - } else { - false - } - } - } - } - } - - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def get[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - CachedKafkaConsumer.synchronized { - val k = CacheKey(groupId, topic, partition) - val v = cache.get(k) - if (null == v) { - logInfo(s"Cache miss for $k") - logDebug(cache.keySet.toString) - val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - cache.put(k, c) - c - } else { - // any given topicpartition should have a consistent key and value type - v.asInstanceOf[CachedKafkaConsumer[K, V]] - } - } - - /** - * Get a fresh new instance, unassociated with the global cache. - * Caller is responsible for closing - */ - def getUncached[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - - /** remove consumer for given groupId, topic, and partition, if it exists */ - def remove(groupId: String, topic: String, partition: Int): Unit = { - val k = CacheKey(groupId, topic, partition) - logInfo(s"Removing $k from cache") - val v = CachedKafkaConsumer.synchronized { - cache.remove(k) - } - if (null != v) { - v.close() - logInfo(s"Removed $k from cache") - } - } -} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index c3221481556f5..0246006acf0bd 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -166,6 +166,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( * which would throw off consumer position. Fix position if this happens. 
*/ private def paranoidPoll(c: Consumer[K, V]): Unit = { + // don't actually want to consume any messages, so pause all partitions + c.pause(c.assignment()) val msgs = c.poll(0) if (!msgs.isEmpty) { // position should be minimum offset per topicpartition @@ -204,8 +206,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap - // don't want to consume messages, so pause - c.pause(newPartitions.asJava) // find latest available offsets c.seekToEnd(currentOffsets.keySet.asJava) parts.map(tp => tp -> c.position(tp)).toMap @@ -262,9 +262,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( tp -> c.position(tp) }.toMap } - - // don't actually want to consume any messages, so pause all partitions - c.pause(currentOffsets.keySet.asJava) } override def stop(): Unit = this.synchronized { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala new file mode 100644 index 0000000000000..68c5fe9ab066a --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.{KafkaException, TopicPartition} + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging + +private[kafka010] sealed trait KafkaDataConsumer[K, V] { + /** + * Get the record for the given offset if available. + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.get(offset, pollTimeoutMs) + } + + /** + * Start a batch on a compacted topic + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + internalConsumer.compactedStart(offset, pollTimeoutMs) + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + * + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.compactedNext(pollTimeoutMs) + } + + /** + * Rewind to previous record in the batch from a compacted topic. 
+ * + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + internalConsumer.compactedPrevious() + } + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + def internalConsumer: InternalKafkaConsumer[K, V] +} + + +/** + * A wrapper around Kafka's KafkaConsumer. + * This is not for direct use outside this file. + */ +private[kafka010] class InternalKafkaConsumer[K, V]( + val topicPartition: TopicPartition, + val kafkaParams: ju.Map[String, Object]) extends Logging { + + private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG) + .asInstanceOf[String] + + private val consumer = createConsumer + + /** indicates whether this consumer is in use or not */ + var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + var markedForClose = false + + // TODO if the buffer was kept around as a random-access structure, + // could possibly optimize re-calculating of an RDD in the same batch + @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() + @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET + + override def toString: String = { + "InternalKafkaConsumer(" + + s"hash=${Integer.toHexString(hashCode)}, " + + s"groupId=$groupId, " + + s"topicPartition=$topicPartition)" + } + + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[K, V] = { + val c = new KafkaConsumer[K, V](kafkaParams) + val topics = ju.Arrays.asList(topicPartition) + c.assign(topics) + c + } + + def close(): Unit = consumer.close() + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset") + if (offset != nextOffset) { + logInfo(s"Initial fetch for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + } + + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + var record = buffer.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + record = buffer.next() + require(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. 
If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) + } + + nextOffset = offset + 1 + record + } + + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + logDebug(s"compacted start $groupId $topicPartition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(pollTimeoutMs) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topicPartition " + + s"after polling for $pollTimeoutMs") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(timeout: Long): Unit = { + val p = consumer.poll(timeout) + val r = p.records(topicPartition) + logDebug(s"Polled ${p.partitions()} ${r.size}") + buffer = r.listIterator + } + +} + +private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) + +private[kafka010] object KafkaDataConsumer extends Logging { + + private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + assert(internalConsumer.inUse) + override def release(): Unit = KafkaDataConsumer.release(internalConsumer) + } + + private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + override def release(): Unit = internalConsumer.close() + } + + // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap + private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null + + /** + * Must be called before acquire, once per JVM, to configure the cache. + * Further calls are ignored. + */ + def init( + initialCapacity: Int, + maxCapacity: Int, + loadFactor: Float): Unit = synchronized { + if (null == cache) { + logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") + cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]]( + initialCapacity, loadFactor, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. 
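// Illustrative walk-through of the policy described above (assumed numbers): with
// maxCapacity = 64 and 100 task slots all holding consumers, the eldest entry is always
// in use, so removeEldestEntry returns false and the map grows to roughly 100 entries;
// an entry is evicted only when the least-recently-used one is idle and the size still
// exceeds maxCapacity.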
+ + if (entry.getValue.inUse == false && this.size > maxCapacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $maxCapacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case x: KafkaException => + logError("Error closing oldest Kafka consumer", x) + } + true + } else { + false + } + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by anyone + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. + */ + def acquire[K, V]( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + context: TaskContext, + useCache: Boolean): KafkaDataConsumer[K, V] = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = cache.get(key) + + lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams) + + if (context != null && context.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumers if any and + // start with a new one. If prior attempt failures were cache related then this way old + // problematic consumers can be removed. + logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer") + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + // Remove the consumer from cache only if it's closed. + // Marked for close consumers will be removed in release function. 
+ cache.remove(key) + } + } + + logDebug("Reattempt detected, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (!useCache) { + // If consumer reuse turned off, then do not use it, return a new consumer + logDebug("Cache usage turned off, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + logDebug("No cached consumer, new cached consumer will be allocated " + + s"$newInternalConsumer") + cache.put(key, newInternalConsumer) + CachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + logDebug("Used cached consumer found, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else { + // If consumer is already cached and is currently not in use, then return that consumer + logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer") + existingInternalConsumer.inUse = true + // Any given TopicPartition should have a consistent key and value type + CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]]) + } + } + + private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized { + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition) + val cachedInternalConsumer = cache.get(key) + if (internalConsumer.eq(cachedInternalConsumer)) { + // The released consumer is the same object as the cached one. + if (internalConsumer.markedForClose) { + internalConsumer.close() + cache.remove(key) + } else { + internalConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + internalConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache " + + s"$internalConsumer") + } + } +} + +private[kafka010] object InternalKafkaConsumer { + private val UNKNOWN_OFFSET = -2L +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 07239eda64d2e..4513dca44c7c6 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } -import scala.collection.mutable.ArrayBuffer - import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord } import org.apache.kafka.common.TopicPartition @@ -67,7 +65,7 @@ private[spark] class KafkaRDD[K, V]( // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time? 
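// Note on the change just below: a spark.network.timeout value given without a unit suffix
// (for example "120") is conventionally interpreted in seconds, which is presumably why the
// fallback switches from getTimeAsMs to getTimeAsSeconds multiplied by 1000L.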
private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", - conf.getTimeAsMs("spark.network.timeout", "120s")) + conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val cacheInitialCapacity = conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16) private val cacheMaxCapacity = @@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - - context.addTaskCompletionListener(_ => closeIfNeeded()) + context.addTaskCompletionListener[Unit](_ => closeIfNeeded()) - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + val consumer = { + KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache) } var requestOffset = part.fromOffset def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close() + if (consumer != null) { + consumer.release() } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..d934c64962adb --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ + +class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll { + private var testUtils: KafkaTestUtils = _ + private val topic = "topic" + Random.nextInt() + private val topicPartition = new TopicPartition(topic, 0) + private val groupId = "groupId" + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + KafkaDataConsumer.init(16, 64, 0.75f) + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + private def getKafkaParams() = Map[String, Object]( + GROUP_ID_CONFIG -> groupId, + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ).asJava + + test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") { + KafkaDataConsumer.cache.clear() + + val kafkaParams = getKafkaParams() + + val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer1.release() + + val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer2.release() + + assert(KafkaDataConsumer.cache.size() == 1) + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = KafkaDataConsumer.cache.get(key) + assert(existingInternalConsumer.eq(consumer1.internalConsumer)) + assert(existingInternalConsumer.eq(consumer2.internalConsumer)) + } + + test("concurrent use of KafkaDataConsumer") { + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic) + testUtils.sendMessages(topic, data.toArray) + + val kafkaParams = getKafkaParams() + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, taskContext, useCache) + try { + val rcvd = (0 until data.length).map { offset => + val bytes = consumer.get(offset, 10000).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadPool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadPool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadPool.shutdown() + } + } +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index 
271adea1df731..3ac6509b04707 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -23,11 +23,11 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Random -import kafka.common.TopicAndPartition -import kafka.log._ -import kafka.message._ +import kafka.log.{CleanerConfig, Log, LogCleaner, LogConfig, ProducerStateManager} +import kafka.server.{BrokerTopicStats, LogDirFailureChannel} import kafka.utils.Pool import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll @@ -72,33 +72,39 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { val mockTime = new MockTime() - // LogCleaner in 0.10 version of Kafka is still expecting the old TopicAndPartition api - val logs = new Pool[TopicAndPartition, Log]() + val logs = new Pool[TopicPartition, Log]() val logDir = kafkaTestUtils.brokerLogDir val dir = new File(logDir, topic + "-" + partition) dir.mkdirs() val logProps = new ju.Properties() logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) + val logDirFailureChannel = new LogDirFailureChannel(1) + val topicPartition = new TopicPartition(topic, partition) val log = new Log( dir, LogConfig(logProps), 0L, + 0L, mockTime.scheduler, - mockTime + new BrokerTopicStats(), + mockTime, + Int.MaxValue, + Int.MaxValue, + topicPartition, + new ProducerStateManager(topicPartition, dir), + logDirFailureChannel ) messages.foreach { case (k, v) => - val msg = new ByteBufferMessageSet( - NoCompressionCodec, - new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue)) - log.append(msg) + val record = new SimpleRecord(k.getBytes, v.getBytes) + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, record), 0); } log.roll() - logs.put(TopicAndPartition(topic, partition), log) + logs.put(topicPartition, log) - val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs) + val cleaner = new LogCleaner(CleanerConfig(), Array(dir), logs, logDirFailureChannel) cleaner.startup() - cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000) + cleaner.awaitCleaned(new TopicPartition(topic, partition), log.activeSegment.baseOffset, 1000) cleaner.shutdown() mockTime.scheduler.shutdown() diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 70b579d96d692..efcd5d6a5cdd3 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -32,13 +32,14 @@ import kafka.api.Request import kafka.server.{KafkaConfig, KafkaServer} import kafka.utils.ZkUtils import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.StringSerializer import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import 
org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.streaming.Time -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -49,17 +50,17 @@ import org.apache.spark.util.Utils private[kafka010] class KafkaTestUtils extends Logging { // Zookeeper related configurations - private val zkHost = "localhost" + private val zkHost = "127.0.0.1" private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 - private val zkSessionTimeout = 6000 + private val zkSessionTimeout = 10000 private var zookeeper: EmbeddedZookeeper = _ private var zkUtils: ZkUtils = _ // Kafka broker related configurations - private val brokerHost = "localhost" + private val brokerHost = "127.0.0.1" private var brokerPort = 0 private var brokerConf: KafkaConfig = _ @@ -72,6 +73,7 @@ private[kafka010] class KafkaTestUtils extends Logging { // Flag to test whether the system is correctly started private var zkReady = false private var brokerReady = false + private var leakDetector: AnyRef = null def zkAddress: String = { assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") @@ -109,7 +111,7 @@ private[kafka010] class KafkaTestUtils extends Logging { brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) server = new KafkaServer(brokerConf) server.startup() - brokerPort = server.boundPort() + brokerPort = server.boundPort(new ListenerName("PLAINTEXT")) (server, brokerPort) }, new SparkConf(), "KafkaBroker") @@ -118,12 +120,22 @@ private[kafka010] class KafkaTestUtils extends Logging { /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ def setup(): Unit = { + // Set up a KafkaTestUtils leak detector so that we can see where the leak KafkaTestUtils is + // created. 
+ val exception = new SparkException("It was created at: ") + leakDetector = ShutdownHookManager.addShutdownHook { () => + logError("Found a leak KafkaTestUtils.", exception) + } + setupEmbeddedZookeeper() setupEmbeddedKafkaServer() } /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { + if (leakDetector != null) { + ShutdownHookManager.removeShutdownHook(leakDetector) + } brokerReady = false zkReady = false @@ -216,12 +228,18 @@ private[kafka010] class KafkaTestUtils extends Logging { private def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") - props.put("host.name", "localhost") + props.put("host.name", "127.0.0.1") + props.put("advertised.host.name", "127.0.0.1") props.put("port", brokerPort.toString) props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) + props.put("zookeeper.connection.timeout.ms", "60000") props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") + props.put("delete.topic.enable", "true") + props.put("offsets.topic.num.partitions", "1") + props.put("offsets.topic.replication.factor", "1") + props.put("group.initial.rebalance.delay.ms", "10") props } @@ -270,12 +288,10 @@ private[kafka010] class KafkaTestUtils extends Logging { private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => - val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr - + val leader = partitionState.basePartitionState.leader + val isr = partitionState.basePartitionState.isr zkUtils.getLeaderForPartition(topic, partition).isDefined && - Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && - leaderAndInSyncReplicas.isr.nonEmpty - + Request.isValidBrokerId(leader) && !isr.isEmpty case _ => false } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala index 928e1a6ef54b9..4811d041e7e9e 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala @@ -21,7 +21,8 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable.PriorityQueue -import kafka.utils.{Scheduler, Time} +import kafka.utils.Scheduler +import org.apache.kafka.common.utils.Time /** * A mock scheduler that executes tasks synchronously using a mock time instance. diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala index a68f94db1f689..8a8646ee4eb94 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010.mocks import java.util.concurrent._ -import kafka.utils.Time +import org.apache.kafka.common.utils.Time /** * A class used for unit testing things which depend on the Time interface. 
@@ -36,12 +36,14 @@ private[kafka010] class MockTime(@volatile private var currentMs: Long) extends def this() = this(System.currentTimeMillis) - def milliseconds: Long = currentMs + override def milliseconds: Long = currentMs - def nanoseconds: Long = + override def hiResClockMs(): Long = milliseconds + + override def nanoseconds: Long = TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) - def sleep(ms: Long) { + override def sleep(ms: Long) { this.currentMs += ms scheduler.tick() } diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 41bc8b3e3ee1f..6be17a81f3fed 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -95,11 +95,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.scala-lang scala-library diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 5ea52b6ad36a0..791cf0efaf888 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -191,6 +191,7 @@ class KafkaRDD[ private def fetchBatch: Iterator[MessageAndOffset] = { val req = new FetchRequestBuilder() + .clientId(consumer.clientId) .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) .build() val resp = consumer.fetch(req) diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 37c7d1e604ec5..68fded515626b 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -89,11 +89,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.apache.hadoop hadoop-client diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index fa0de6298a5f1..69c52365b1bf8 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -160,7 +160,6 @@ private[kinesis] class KinesisReceiver[T]( cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPosition.getPosition) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) @@ -169,7 +168,8 @@ private[kinesis] class KinesisReceiver[T]( initialPosition match { case ts: AtTimestamp => baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) - case _ => baseClientLibConfiguration + case _ => + baseClientLibConfiguration.withInitialPositionInStream(initialPosition.getPosition) } } diff --git a/graphx/pom.xml b/graphx/pom.xml index fbe77fcb958d5..0f5dc548600b2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -53,7 +53,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded com.google.guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index ebd65e8320e5c..96b635f9a144e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -184,9 +184,11 @@ object PageRank extends Logging { * indexed by the position of nodes in the sources 
list) and * edge attributes the normalized edge weight */ - def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], - numIter: Int, resetProb: Double = 0.15, - sources: Array[VertexId]): Graph[Vector, Double] = { + def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], + numIter: Int, + resetProb: Double = 0.15, + sources: Array[VertexId]): Graph[Vector, Double] = { require(numIter > 0, s"Number of iterations must be greater than 0," + s" but got ${numIter}") require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" + @@ -194,15 +196,11 @@ object PageRank extends Logging { require(sources.nonEmpty, s"The list of sources must be non-empty," + s" but got ${sources.mkString("[", ",", "]")}") - // TODO if one sources vertex id is outside of the int range - // we won't be able to store its activations in a sparse vector - require(sources.max <= Int.MaxValue.toLong, - s"This implementation currently only works for source vertex ids at most ${Int.MaxValue}") val zero = Vectors.sparse(sources.size, List()).asBreeze - val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) => - val v = Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze - (vid, v) - }.toMap + // map of vid -> vector where for each vid, the _position of vid in source_ is set to 1.0 + val sourcesInitMap = sources.zipWithIndex.toMap.mapValues { i => + Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze + } val sc = graph.vertices.sparkContext val sourcesInitMapBC = sc.broadcast(sourcesInitMap) // Initialize the PageRank graph with each edge attribute having @@ -212,13 +210,7 @@ object PageRank extends Logging { .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree .mapTriplets(e => 1.0 / e.srcAttr, TripletFields.Src) - .mapVertices { (vid, attr) => - if (sourcesInitMapBC.value contains vid) { - sourcesInitMapBC.value(vid) - } else { - zero - } - } + .mapVertices((vid, _) => sourcesInitMapBC.value.getOrElse(vid, zero)) var i = 0 while (i < numIter) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index d76e84ed8c9ed..50b03f71379a1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.util.Utils @@ -109,14 +109,14 @@ private[graphx] object BytecodeUtils { * determine the actual method invoked by inspecting the bytecode. 
*/ private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM5) { + extends ClassVisitor(ASM6) { val methodsInvoked = new HashSet[(Class[_], String)] override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name == methodName) { - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitMethodInsn( op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 9779553ce85d1..1e4c6c74bd184 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -203,24 +203,42 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { test("Chain PersonalizedPageRank") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x + 1) ) + // Check that implementation can handle large vertexIds, SPARK-25149 + val vertexIdOffset = Int.MaxValue.toLong + 1 + val sourceOffest = 4 + val source = vertexIdOffset + sourceOffest + val numIter = 10 + val vertices = vertexIdOffset until vertexIdOffset + numIter + val chain1 = vertices.zip(vertices.tail) val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 - val numIter = 10 val errorTol = 1.0e-1 - val staticRanks = chain.staticPersonalizedPageRank(4, numIter, resetProb).vertices - val dynamicRanks = chain.personalizedPageRank(4, tol, resetProb).vertices + val a = resetProb / (1 - Math.pow(1 - resetProb, numIter - sourceOffest)) + // We expect the rank to decay as (1 - resetProb) ^ distance + val expectedRanks = sc.parallelize(vertices).map { vid => + val rank = if (vid < source) { + 0.0 + } else { + a * Math.pow(1 - resetProb, vid - source) + } + vid -> rank + } + val expected = VertexRDD(expectedRanks) + + val staticRanks = chain.staticPersonalizedPageRank(source, numIter, resetProb).vertices + assert(compareRanks(staticRanks, expected) < errorTol) - assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + val dynamicRanks = chain.personalizedPageRank(source, tol, resetProb).vertices + assert(compareRanks(dynamicRanks, expected) < errorTol) val parallelStaticRanks = chain - .staticParallelPersonalizedPageRank(Array(4), numIter, resetProb).mapVertices { + .staticParallelPersonalizedPageRank(Array(source), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(0) }.vertices.cache() - assert(compareRanks(staticRanks, parallelStaticRanks) < errorTol) + assert(compareRanks(parallelStaticRanks, expected) < errorTol) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index 61e44dcab578c..5325978a0a1ec 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ClosureCleanerSuite2 // scalastyle:off println @@ -26,6 +27,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass 
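Aside on the PageRankSuite hunk above: the expected ranks there follow a closed form. On the chain graph, the personalized walk's mass decays by a factor of (1 - resetProb) per hop from the source, and the constant `a` normalizes the resulting geometric series over the vertices reachable within the truncated chain so the ranks sum to one. A minimal sketch of that identity, reusing the constants from that test (the test's `sourceOffest` is spelled `sourceOffset` here):

val resetProb = 0.15
val numIter = 10
val sourceOffset = 4
// Normalizing constant from the test: a = resetProb / (1 - (1 - resetProb)^(numIter - sourceOffset))
val a = resetProb / (1 - math.pow(1 - resetProb, numIter - sourceOffset))
// Expected rank at distance d from the source: a * (1 - resetProb)^d, for d = 0 .. 5
val ranks = (0 until (numIter - sourceOffset)).map(d => a * math.pow(1 - resetProb, d))
// The geometric series telescopes, so the ranks over the reachable vertices sum to 1
assert(math.abs(ranks.sum - 1.0) < 1e-9)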
test("closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); } assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo")) assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar")) @@ -43,6 +45,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure inside a closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); } val c2 = {e: TestClass => c1(e); println(e.foo); } assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "foo")) @@ -51,6 +54,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure inside a closure inside a closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.baz); } val c2 = {e: TestClass => c1(e); println(e.foo); } val c3 = {e: TestClass => c2(e) } @@ -60,6 +64,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure calling a function that invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) def zoo(e: TestClass) { println(e.baz) } @@ -70,6 +75,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure calling a function that invokes a method which uses another closure") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c2 = {e: TestClass => println(e.baz)} def zoo(e: TestClass) { c2(e) @@ -81,6 +87,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("nested closure") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c2 = {e: TestClass => println(e.baz)} def zoo(e: TestClass, c: TestClass => Unit) { c(e) diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 8e424b1c50236..2c39a7df0146e 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -38,7 +38,32 @@ hadoop-cloud + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + provided + + + + hadoop-3.1 + + + + org.apache.hadoop + hadoop-cloud-storage + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + + + org.eclipse.jetty + jetty-util-ajax + ${jetty.version} + ${hadoop.deps.scope} + + + + diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java index 4e02843480e8f..8a1256f73416e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java @@ -28,7 +28,7 @@ * * @since Spark 2.3.0 */ -public abstract class AbstractLauncher { +public abstract class AbstractLauncher> { final SparkSubmitCommandBuilder builder; diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 1e34bb8c73279..d967aa39a4827 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; 
+import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -54,10 +55,12 @@ public static void main(String[] argsArray) throws Exception { String className = args.remove(0); boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - AbstractCommandBuilder builder; + Map env = new HashMap<>(); + List cmd; if (className.equals("org.apache.spark.deploy.SparkSubmit")) { try { - builder = new SparkSubmitCommandBuilder(args); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(args); + cmd = buildCommand(builder, env, printLaunchCommand); } catch (IllegalArgumentException e) { printLaunchCommand = false; System.err.println("Error: " + e.getMessage()); @@ -76,17 +79,12 @@ public static void main(String[] argsArray) throws Exception { help.add(parser.className); } help.add(parser.USAGE_ERROR); - builder = new SparkSubmitCommandBuilder(help); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(help); + cmd = buildCommand(builder, env, printLaunchCommand); } } else { - builder = new SparkClassCommandBuilder(className, args); - } - - Map env = new HashMap<>(); - List cmd = builder.buildCommand(env); - if (printLaunchCommand) { - System.err.println("Spark Command: " + join(" ", cmd)); - System.err.println("========================================"); + AbstractCommandBuilder builder = new SparkClassCommandBuilder(className, args); + cmd = buildCommand(builder, env, printLaunchCommand); } if (isWindows()) { @@ -101,6 +99,22 @@ public static void main(String[] argsArray) throws Exception { } } + /** + * Prepare spark commands with the appropriate command builder. + * If printLaunchCommand is set then the commands will be printed to the stderr. + */ + private static List buildCommand( + AbstractCommandBuilder builder, + Map env, + boolean printLaunchCommand) throws IOException, IllegalArgumentException { + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + return cmd; + } + /** * Prepare a command line for execution from a Windows batch script. * diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5cb6457bf5c21..cc65f78b45c30 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -90,7 +90,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { final List userArgs; private final List parsedArgs; - private final boolean requiresAppResource; + // Special command means no appResource and no mainClass required + private final boolean isSpecialCommand; private final boolean isExample; /** @@ -105,7 +106,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { * spark-submit argument list to be modified after creation. 
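
The Main.java refactoring above only reshuffles how the launcher assembles its command line; applications normally reach it through the public SparkLauncher API. A usage sketch from Scala follows, with a placeholder jar, class name and master URL; SPARK_PRINT_LAUNCH_COMMAND is the environment variable the hunk above checks before echoing the final command to stderr, and passing it through the child environment is assumed to surface that output.

```scala
// Illustrative use of the public launcher API whose internals are refactored above.
// The jar path, class name and master URL are placeholders, not values from this diff.
import scala.collection.JavaConverters._
import org.apache.spark.launcher.SparkLauncher

object LaunchExample {
  def main(args: Array[String]): Unit = {
    // Child-process environment; SPARK_PRINT_LAUNCH_COMMAND is what Main.java
    // checks above before printing the final command to stderr.
    val env = Map("SPARK_PRINT_LAUNCH_COMMAND" -> "1").asJava

    val proc = new SparkLauncher(env)
      .setAppResource("/path/to/app.jar")   // placeholder
      .setMainClass("com.example.MyApp")    // placeholder
      .setMaster("local[2]")
      .setConf(SparkLauncher.DRIVER_MEMORY, "1g")
      .launch()
    sys.exit(proc.waitFor())
  }
}
```
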
*/ SparkSubmitCommandBuilder() { - this.requiresAppResource = true; + this.isSpecialCommand = false; this.isExample = false; this.parsedArgs = new ArrayList<>(); this.userArgs = new ArrayList<>(); @@ -138,25 +139,26 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { case RUN_EXAMPLE: isExample = true; + appResource = SparkLauncher.NO_RESOURCE; submitArgs = args.subList(1, args.size()); } this.isExample = isExample; OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.requiresAppResource = parser.requiresAppResource; + this.isSpecialCommand = parser.isSpecialCommand; } else { this.isExample = isExample; - this.requiresAppResource = false; + this.isSpecialCommand = true; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { + if (PYSPARK_SHELL.equals(appResource) && !isSpecialCommand) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { + } else if (SPARKR_SHELL.equals(appResource) && !isSpecialCommand) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -166,18 +168,18 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); OptionParser parser = new OptionParser(false); - final boolean requiresAppResource; + final boolean isSpecialCommand; // If the user args array is not empty, we need to parse it to detect exactly what // the user is trying to run, so that checks below are correct. if (!userArgs.isEmpty()) { parser.parse(userArgs); - requiresAppResource = parser.requiresAppResource; + isSpecialCommand = parser.isSpecialCommand; } else { - requiresAppResource = this.requiresAppResource; + isSpecialCommand = this.isSpecialCommand; } - if (!allowsMixedArguments && requiresAppResource) { + if (!allowsMixedArguments && !isSpecialCommand) { checkArgument(appResource != null, "Missing application resource."); } @@ -229,7 +231,7 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isExample) { + if (isExample && !isSpecialCommand) { checkArgument(mainClass != null, "Missing example class name."); } @@ -421,7 +423,7 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean requiresAppResource = true; + boolean isSpecialCommand = false; private final boolean errorOnUnknownArgs; OptionParser(boolean errorOnUnknownArgs) { @@ -470,17 +472,14 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - requiresAppResource = false; - parsedArgs.add(opt); - break; case VERSION: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); break; default: diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 2e050f8413074..b343094b2e7b8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -27,7 +28,10 @@ import 
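
The requiresAppResource -> isSpecialCommand rename above inverts the flag: --kill, --status, --help, --usage-error and --version mark "special" invocations that need neither an application resource nor a main class (and, combined with run-example, no example class). The stand-alone Scala restatement below is only illustrative; the real classification is the OptionParser switch in the hunk, and the literal flag spellings are assumed from the SparkSubmitOptionParser constants it names.

```scala
// Stand-alone restatement of the "special command" classification introduced above;
// the real logic lives in SparkSubmitCommandBuilder.OptionParser (Java), not here.
object SpecialCommands {
  // Flags after which spark-submit needs no appResource / mainClass (assumed spellings).
  private val specialFlags =
    Set("--kill", "--status", "--help", "--usage-error", "--version")

  def isSpecialCommand(args: Seq[String]): Boolean =
    args.exists(specialFlags.contains)

  def main(args: Array[String]): Unit = {
    println(isSpecialCommand(Seq("--status", "driver-20160531171222-0000")))  // true
    println(isSpecialCommand(Seq("--master", "local", "app.jar")))            // false
  }
}
```
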
org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + import static org.junit.Assert.*; public class SparkSubmitCommandBuilderSuite extends BaseSuite { @@ -35,6 +39,9 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @BeforeClass public static void setUp() throws Exception { dummyPropsFile = File.createTempFile("spark", "properties"); @@ -74,8 +81,11 @@ public void testCliHelpAndNoArg() throws Exception { @Test public void testCliKillAndStatus() throws Exception { - testCLIOpts(parser.STATUS); - testCLIOpts(parser.KILL_SUBMISSION); + List params = Arrays.asList("driver-20160531171222-0000"); + testCLIOpts(null, parser.STATUS, params); + testCLIOpts(null, parser.KILL_SUBMISSION, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.STATUS, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.KILL_SUBMISSION, params); } @Test @@ -190,6 +200,33 @@ public void testSparkRShell() throws Exception { env.get("SPARKR_SUBMIT_ARGS")); } + @Test(expected = IllegalArgumentException.class) + public void testExamplesRunnerNoArg() throws Exception { + List sparkSubmitArgs = Arrays.asList(SparkSubmitCommandBuilder.RUN_EXAMPLE); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + + @Test + public void testExamplesRunnerNoMainClass() throws Exception { + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.HELP, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.USAGE_ERROR, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.VERSION, null); + } + + @Test + public void testExamplesRunnerWithMasterNoMainClass() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing example class name."); + + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.RUN_EXAMPLE, + parser.MASTER + "=foo" + ); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + @Test public void testExamplesRunner() throws Exception { List sparkSubmitArgs = Arrays.asList( @@ -344,10 +381,17 @@ private List buildCommand(List args, Map env) th return newCommandBuilder(args).buildCommand(env); } - private void testCLIOpts(String opt) throws Exception { - List helpArgs = Arrays.asList(opt, "driver-20160531171222-0000"); + private void testCLIOpts(String appResource, String opt, List params) throws Exception { + List args = new ArrayList<>(); + if (appResource != null) { + args.add(appResource); + } + args.add(opt); + if (params != null) { + args.addAll(params); + } Map env = new HashMap<>(); - List cmd = buildCommand(helpArgs, env); + List cmd = buildCommand(args, env); assertTrue(opt + " should be contained in the final cmd.", cmd.contains(opt)); } diff --git a/licenses/LICENSE-scopt.txt b/licenses-binary/LICENSE-AnchorJS.txt similarity index 100% rename from licenses/LICENSE-scopt.txt rename to licenses-binary/LICENSE-AnchorJS.txt diff --git a/licenses-binary/LICENSE-CC0.txt b/licenses-binary/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses-binary/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. 
DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. 
To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. 
+ Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-antlr.txt b/licenses-binary/LICENSE-antlr.txt similarity index 100% rename from licenses/LICENSE-antlr.txt rename to licenses-binary/LICENSE-antlr.txt diff --git a/licenses-binary/LICENSE-arpack.txt b/licenses-binary/LICENSE-arpack.txt new file mode 100644 index 0000000000000..a3ad80087bb63 --- /dev/null +++ b/licenses-binary/LICENSE-arpack.txt @@ -0,0 +1,8 @@ +Copyright © 2018 The University of Tennessee. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +· Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +· Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. +· Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +This software is provided by the copyright holders and contributors "as is" and any express or implied warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. in no event shall the copyright owner or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. \ No newline at end of file diff --git a/licenses-binary/LICENSE-automaton.txt b/licenses-binary/LICENSE-automaton.txt new file mode 100644 index 0000000000000..2fc6e8c3432f0 --- /dev/null +++ b/licenses-binary/LICENSE-automaton.txt @@ -0,0 +1,24 @@ +Copyright (c) 2001-2017 Anders Moeller +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +3. The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
+IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-bootstrap.txt b/licenses-binary/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses-binary/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses-binary/LICENSE-cloudpickle.txt b/licenses-binary/LICENSE-cloudpickle.txt new file mode 100644 index 0000000000000..b1e20fa1eda88 --- /dev/null +++ b/licenses-binary/LICENSE-cloudpickle.txt @@ -0,0 +1,28 @@ +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-d3.min.js.txt b/licenses-binary/LICENSE-d3.min.js.txt new file mode 100644 index 0000000000000..c71e3f254c068 --- /dev/null +++ b/licenses-binary/LICENSE-d3.min.js.txt @@ -0,0 +1,26 @@ +Copyright (c) 2010-2015, Michael Bostock +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name Michael Bostock may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-Mockito.txt b/licenses-binary/LICENSE-dagre-d3.txt similarity index 94% rename from licenses/LICENSE-Mockito.txt rename to licenses-binary/LICENSE-dagre-d3.txt index e0840a446caf5..4864fe05e9803 100644 --- a/licenses/LICENSE-Mockito.txt +++ b/licenses-binary/LICENSE-dagre-d3.txt @@ -1,6 +1,4 @@ -The MIT License - -Copyright (c) 2007 Mockito contributors +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses-binary/LICENSE-datatables.txt b/licenses-binary/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses-binary/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/licenses/LICENSE-f2j.txt b/licenses-binary/LICENSE-f2j.txt similarity index 100% rename from licenses/LICENSE-f2j.txt rename to licenses-binary/LICENSE-f2j.txt diff --git a/licenses-binary/LICENSE-graphlib-dot.txt b/licenses-binary/LICENSE-graphlib-dot.txt new file mode 100644 index 0000000000000..4864fe05e9803 --- /dev/null +++ b/licenses-binary/LICENSE-graphlib-dot.txt @@ -0,0 +1,19 @@ +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-heapq.txt b/licenses-binary/LICENSE-heapq.txt new file mode 100644 index 0000000000000..0c4c4b954bea4 --- /dev/null +++ b/licenses-binary/LICENSE-heapq.txt @@ -0,0 +1,280 @@ + +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? 
(1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. +# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. 
BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. 
As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. 
+# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-janino.txt b/licenses-binary/LICENSE-janino.txt new file mode 100644 index 0000000000000..d1e1f237c4641 --- /dev/null +++ b/licenses-binary/LICENSE-janino.txt @@ -0,0 +1,31 @@ +Janino - An embedded Java[TM] compiler + +Copyright (c) 2001-2016, Arno Unkrig +Copyright (c) 2015-2016 TIBCO Software Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + 3. Neither the name of JANINO nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-javassist.html b/licenses-binary/LICENSE-javassist.html new file mode 100644 index 0000000000000..5abd563a0c4d9 --- /dev/null +++ b/licenses-binary/LICENSE-javassist.html @@ -0,0 +1,373 @@ + + + Javassist License + + + + +
    MOZILLA PUBLIC LICENSE
    Version 1.1

    1. Definitions.

      1.0.1. "Commercial Use" means distribution or otherwise making the + Covered Code available to a third party. +

      1.1. ''Contributor'' means each entity that creates or contributes + to the creation of Modifications. +

      1.2. ''Contributor Version'' means the combination of the Original + Code, prior Modifications used by a Contributor, and the Modifications made by + that particular Contributor. +

      1.3. ''Covered Code'' means the Original Code or Modifications or + the combination of the Original Code and Modifications, in each case including + portions thereof. +

      1.4. ''Electronic Distribution Mechanism'' means a mechanism + generally accepted in the software development community for the electronic + transfer of data. +

      1.5. ''Executable'' means Covered Code in any form other than Source + Code. +

      1.6. ''Initial Developer'' means the individual or entity identified + as the Initial Developer in the Source Code notice required by Exhibit + A. +

      1.7. ''Larger Work'' means a work which combines Covered Code or + portions thereof with code not governed by the terms of this License. +

      1.8. ''License'' means this document. +

      1.8.1. "Licensable" means having the right to grant, to the maximum + extent possible, whether at the time of the initial grant or subsequently + acquired, any and all of the rights conveyed herein. +

      1.9. ''Modifications'' means any addition to or deletion from the + substance or structure of either the Original Code or any previous + Modifications. When Covered Code is released as a series of files, a + Modification is: +

        A. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +

        B. Any new file that contains any part of the Original Code or + previous Modifications.
         

      1.10. ''Original Code'' +means Source Code of computer software code which is described in the Source +Code notice required by Exhibit A as Original Code, and which, at the +time of its release under this License is not already Covered Code governed by +this License. +

      1.10.1. "Patent Claims" means any patent claim(s), now owned or + hereafter acquired, including without limitation,  method, process, and + apparatus claims, in any patent Licensable by grantor. +

      1.11. ''Source Code'' means the preferred form of the Covered Code + for making modifications to it, including all modules it contains, plus any + associated interface definition files, scripts used to control compilation and + installation of an Executable, or source code differential comparisons against + either the Original Code or another well known, available Covered Code of the + Contributor's choice. The Source Code can be in a compressed or archival form, + provided the appropriate decompression or de-archiving software is widely + available for no charge. +

      1.12. "You'' (or "Your")  means an individual or a legal entity + exercising rights under, and complying with all of the terms of, this License + or a future version of this License issued under Section 6.1. For legal + entities, "You'' includes any entity which controls, is controlled by, or is + under common control with You. For purposes of this definition, "control'' + means (a) the power, direct or indirect, to cause the direction or management + of such entity, whether by contract or otherwise, or (b) ownership of more + than fifty percent (50%) of the outstanding shares or beneficial ownership of + such entity.

    2. Source Code License. +
      2.1. The Initial Developer Grant.
      The Initial Developer hereby + grants You a world-wide, royalty-free, non-exclusive license, subject to third + party intellectual property claims: +
        (a)  under intellectual property rights (other than + patent or trademark) Licensable by Initial Developer to use, reproduce, + modify, display, perform, sublicense and distribute the Original Code (or + portions thereof) with or without Modifications, and/or as part of a Larger + Work; and +

        (b) under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for + sale, and/or otherwise dispose of the Original Code (or portions thereof). +

          (c) the licenses granted in this Section 2.1(a) and (b) + are effective on the date Initial Developer first distributes Original Code + under the terms of this License. +

          (d) Notwithstanding Section 2.1(b) above, no patent license is + granted: 1) for code that You delete from the Original Code; 2) separate + from the Original Code;  or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original + Code with other software or devices.
           

        2.2. Contributor + Grant.
        Subject to third party intellectual property claims, each + Contributor hereby grants You a world-wide, royalty-free, non-exclusive + license +

          (a)  under intellectual property rights (other + than patent or trademark) Licensable by Contributor, to use, reproduce, + modify, display, perform, sublicense and distribute the Modifications + created by such Contributor (or portions thereof) either on an unmodified + basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +

          (b) under Patent Claims infringed by the making, using, or selling + of  Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such + combination), to make, use, sell, offer for sale, have made, and/or + otherwise dispose of: 1) Modifications made by that Contributor (or portions + thereof); and 2) the combination of  Modifications made by that + Contributor with its Contributor Version (or portions of such + combination). +

          (c) the licenses granted in Sections 2.2(a) and 2.2(b) are + effective on the date Contributor first makes Commercial Use of the Covered + Code. +

          (d)    Notwithstanding Section 2.2(b) above, no + patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2)  separate from the Contributor + Version;  3)  for infringements caused by: i) third party + modifications of Contributor Version or ii)  the combination of + Modifications made by that Contributor with other software  (except as + part of the Contributor Version) or other devices; or 4) under Patent Claims + infringed by Covered Code in the absence of Modifications made by that + Contributor.



      3. Distribution Obligations. +

        3.1. Application of License.
        The Modifications which You create + or to which You contribute are governed by the terms of this License, + including without limitation Section 2.2. The Source Code version of + Covered Code may be distributed only under the terms of this License or a + future version of this License released under Section 6.1, and You must + include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version + that alters or restricts the applicable version of this License or the + recipients' rights hereunder. However, You may include an additional document + offering the additional rights described in Section 3.5. +

        3.2. Availability of Source Code.
        Any Modification which You + create or to which You contribute must be made available in Source Code form + under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom + you made an Executable version available; and if made available via Electronic + Distribution Mechanism, must remain available for at least twelve (12) months + after the date it initially became available, or at least six (6) months after + a subsequent version of that particular Modification has been made available + to such recipients. You are responsible for ensuring that the Source Code + version remains available even if the Electronic Distribution Mechanism is + maintained by a third party. +

        3.3. Description of Modifications.
        You must cause all Covered + Code to which You contribute to contain a file documenting the changes You + made to create that Covered Code and the date of any change. You must include + a prominent statement that the Modification is derived, directly or + indirectly, from Original Code provided by the Initial Developer and including + the name of the Initial Developer in (a) the Source Code, and (b) in any + notice in an Executable version or related documentation in which You describe + the origin or ownership of the Covered Code. +

        3.4. Intellectual Property Matters +

          (a) Third Party Claims.
          If Contributor has knowledge that a + license under a third party's intellectual property rights is required to + exercise the rights granted by such Contributor under Sections 2.1 or 2.2, + Contributor must include a text file with the Source Code distribution + titled "LEGAL'' which describes the claim and the party making the claim in + sufficient detail that a recipient will know whom to contact. If Contributor + obtains such knowledge after the Modification is made available as described + in Section 3.2, Contributor shall promptly modify the LEGAL file in all + copies Contributor makes available thereafter and shall take other steps + (such as notifying appropriate mailing lists or newsgroups) reasonably + calculated to inform those who received the Covered Code that new knowledge + has been obtained. +

          (b) Contributor APIs.
          If Contributor's Modifications include + an application programming interface and Contributor has knowledge of patent + licenses which are reasonably necessary to implement that API, Contributor + must also include this information in the LEGAL file. +
           

          (c) Representations.
          Contributor represents that, except as disclosed pursuant to Section + 3.4(a) above, Contributor believes that Contributor's Modifications are + Contributor's original creation(s) and/or Contributor has sufficient rights + to grant the rights conveyed by this License.


        3.5. Required Notices.
        You must duplicate the notice in + Exhibit A in each file of the Source Code.  If it is not possible + to put such notice in a particular Source Code file due to its structure, then + You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice.  If You created + one or more Modification(s) You may add your name as a Contributor to the + notice described in Exhibit A.  You must also duplicate this + License in any documentation for the Source Code where You describe + recipients' rights or ownership rights relating to Covered Code.  You may + choose to offer, and to charge a fee for, warranty, support, indemnity or + liability obligations to one or more recipients of Covered Code. However, You + may do so only on Your own behalf, and not on behalf of the Initial Developer + or any Contributor. You must make it absolutely clear than any such warranty, + support, indemnity or liability obligation is offered by You alone, and You + hereby agree to indemnify the Initial Developer and every Contributor for any + liability incurred by the Initial Developer or such Contributor as a result of + warranty, support, indemnity or liability terms You offer. +

        3.6. Distribution of Executable Versions.
        You may distribute + Covered Code in Executable form only if the requirements of Section + 3.1-3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available + under the terms of this License, including a description of how and where You + have fulfilled the obligations of Section 3.2. The notice must be + conspicuously included in any notice in an Executable version, related + documentation or collateral in which You describe recipients' rights relating + to the Covered Code. You may distribute the Executable version of Covered Code + or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the + terms of this License and that the license for the Executable version does not + attempt to limit or alter the recipient's rights in the Source Code version + from the rights set forth in this License. If You distribute the Executable + version under a different license You must make it absolutely clear that any + terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the + Initial Developer and every Contributor for any liability incurred by the + Initial Developer or such Contributor as a result of any such terms You offer. + +

        3.7. Larger Works.
        You may create a Larger Work by combining + Covered Code with other code not governed by the terms of this License and + distribute the Larger Work as a single product. In such a case, You must make + sure the requirements of this License are fulfilled for the Covered + Code.

      4. Inability to Comply Due to Statute or Regulation. +
        If it is impossible for You to comply with any of the terms of this + License with respect to some or all of the Covered Code due to statute, + judicial order, or regulation then You must: (a) comply with the terms of this + License to the maximum extent possible; and (b) describe the limitations and + the code they affect. Such description must be included in the LEGAL file + described in Section 3.4 and must be included with all distributions of + the Source Code. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it.
      5. Application of this License. +
        This License applies to code to which the Initial Developer has attached + the notice in Exhibit A and to related Covered Code.
      6. Versions + of the License. +
        6.1. New Versions.
        Netscape Communications Corporation + (''Netscape'') may publish revised and/or new versions of the License from + time to time. Each version will be given a distinguishing version number. +

        6.2. Effect of New Versions.
        Once Covered Code has been + published under a particular version of the License, You may always continue + to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License + published by Netscape. No one other than Netscape has the right to modify the + terms applicable to Covered Code created under this License. +

        6.3. Derivative Works.
        If You create or use a modified version + of this License (which you may only do in order to apply it to code which is + not already Covered Code governed by this License), You must (a) rename Your + license so that the phrases ''Mozilla'', ''MOZILLAPL'', ''MOZPL'', + ''Netscape'', "MPL", ''NPL'' or any confusingly similar phrase do not appear + in your license (except to note that your license differs from this License) + and (b) otherwise make it clear that Your version of the license contains + terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or + Contributor in the notice described in Exhibit A shall not of + themselves be deemed to be modifications of this License.)

      7. + DISCLAIMER OF WARRANTY. +
        COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS'' BASIS, WITHOUT + WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT + LIMITATION, WARRANTIES THAT THE COVERED CODE IS FREE OF DEFECTS, MERCHANTABLE, + FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. THE ENTIRE RISK AS TO THE + QUALITY AND PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED + CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY + OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR + CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS + LICENSE. NO USE OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS + DISCLAIMER.
      8. TERMINATION. +
        8.1.  This License and the rights granted hereunder will + terminate automatically if You fail to comply with terms herein and fail to + cure such breach within 30 days of becoming aware of the breach. All + sublicenses to the Covered Code which are properly granted shall survive any + termination of this License. Provisions which, by their nature, must remain in + effect beyond the termination of this License shall survive. +

        8.2.  If You initiate litigation by asserting a patent + infringement claim (excluding declatory judgment actions) against Initial + Developer or a Contributor (the Initial Developer or Contributor against whom + You file such action is referred to as "Participant")  alleging that: +

        (a)  such Participant's Contributor Version directly or + indirectly infringes any patent, then any and all rights granted by such + Participant to You under Sections 2.1 and/or 2.2 of this License shall, upon + 60 days notice from Participant terminate prospectively, unless if within 60 + days after receipt of notice You either: (i)  agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future + use of Modifications made by such Participant, or (ii) withdraw Your + litigation claim with respect to the Contributor Version against such + Participant.  If within 60 days of notice, a reasonable royalty and + payment arrangement are not mutually agreed upon in writing by the parties or + the litigation claim is not withdrawn, the rights granted by Participant to + You under Sections 2.1 and/or 2.2 automatically terminate at the expiration of + the 60 day notice period specified above. +

        (b)  any software, hardware, or device, other than such + Participant's Contributor Version, directly or indirectly infringes any + patent, then any rights granted to You by such Participant under Sections + 2.1(b) and 2.2(b) are revoked effective as of the date You first made, used, + sold, distributed, or had made, Modifications made by that Participant. +

        8.3.  If You assert a patent infringement claim against + Participant alleging that such Participant's Contributor Version directly or + indirectly infringes any patent where such claim is resolved (such as by + license or settlement) prior to the initiation of patent infringement + litigation, then the reasonable value of the licenses granted by such + Participant under Sections 2.1 or 2.2 shall be taken into account in + determining the amount or value of any payment or license. +

        8.4.  In the event of termination under Sections 8.1 or 8.2 + above,  all end user license agreements (excluding distributors and + resellers) which have been validly granted by You or any distributor hereunder + prior to termination shall survive termination.

      9. LIMITATION OF + LIABILITY. +
        UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING + NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY + OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED CODE, OR ANY SUPPLIER OF ANY + OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, + INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT + LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR + MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH + PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS + LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL + INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW + PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR + LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND + LIMITATION MAY NOT APPLY TO YOU.
      10. U.S. GOVERNMENT END USERS. +
        The Covered Code is a ''commercial item,'' as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of ''commercial computer software'' and + ''commercial computer software documentation,'' as such terms are used in 48 + C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein.
      11. + MISCELLANEOUS. +
        This License represents the complete agreement concerning subject matter + hereof. If any provision of this License is held to be unenforceable, such + provision shall be reformed only to the extent necessary to make it + enforceable. This License shall be governed by California law provisions + (except to the extent applicable law, if any, provides otherwise), excluding + its conflict-of-law provisions. With respect to disputes in which at least one + party is a citizen of, or an entity chartered or registered to do business in + the United States of America, any litigation relating to this License shall be + subject to the jurisdiction of the Federal Courts of the Northern District of + California, with venue lying in Santa Clara County, California, with the + losing party responsible for costs, including without limitation, court costs + and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is + expressly excluded. Any law or regulation which provides that the language of + a contract shall be construed against the drafter shall not apply to this + License.
      12. RESPONSIBILITY FOR CLAIMS. +
        As between Initial Developer and the Contributors, each party is + responsible for claims and damages arising, directly or indirectly, out of its + utilization of rights under this License and You agree to work with Initial + Developer and Contributors to distribute such responsibility on an equitable + basis. Nothing herein is intended or shall be deemed to constitute any + admission of liability.
      13. MULTIPLE-LICENSED CODE. +
        Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed".  "Multiple-Licensed" means that the Initial + Developer permits you to utilize portions of the Covered Code under Your + choice of the MPL or the alternative licenses, if any, specified by the + Initial Developer in the file described in Exhibit A.
      +


      EXHIBIT A -Mozilla Public License. +

        The contents of this file are subject to the Mozilla Public License + Version 1.1 (the "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at +
        http://www.mozilla.org/MPL/ +

        Software distributed under the License is distributed on an "AS IS" basis, + WITHOUT WARRANTY OF
        ANY KIND, either express or implied. See the License + for the specific language governing rights and
        limitations under the + License. +

        The Original Code is Javassist. +

        The Initial Developer of the Original Code is Shigeru Chiba. + Portions created by the Initial Developer are
          + Copyright (C) 1999- Shigeru Chiba. All Rights Reserved. +

        Contributor(s): __Bill Burke, Jason T. Greene______________. + +

        Alternatively, the contents of this software may be used under the + terms of the GNU Lesser General Public License Version 2.1 or later + (the "LGPL"), or the Apache License Version 2.0 (the "AL"), + in which case the provisions of the LGPL or the AL are applicable + instead of those above. If you wish to allow use of your version of + this software only under the terms of either the LGPL or the AL, and not to allow others to + use your version of this software under the terms of the MPL, indicate + your decision by deleting the provisions above and replace them with + the notice and other provisions required by the LGPL or the AL. If you do not + delete the provisions above, a recipient may use your version of this + software under the terms of any one of the MPL, the LGPL or the AL. + +

      + + \ No newline at end of file diff --git a/licenses/LICENSE-javolution.txt b/licenses-binary/LICENSE-javolution.txt similarity index 100% rename from licenses/LICENSE-javolution.txt rename to licenses-binary/LICENSE-javolution.txt diff --git a/licenses/LICENSE-jline.txt b/licenses-binary/LICENSE-jline.txt similarity index 100% rename from licenses/LICENSE-jline.txt rename to licenses-binary/LICENSE-jline.txt diff --git a/licenses/LICENSE-junit-interface.txt b/licenses-binary/LICENSE-jodd.txt similarity index 69% rename from licenses/LICENSE-junit-interface.txt rename to licenses-binary/LICENSE-jodd.txt index e835350c4e2a4..cc6b458adb386 100644 --- a/licenses/LICENSE-junit-interface.txt +++ b/licenses-binary/LICENSE-jodd.txt @@ -1,15 +1,15 @@ -Copyright (c) 2009-2012, Stefan Zeiger +Copyright (c) 2003-present, Jodd Team (https://jodd.org) All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE diff --git a/licenses/LICENSE-DPark.txt b/licenses-binary/LICENSE-join.txt similarity index 100% rename from licenses/LICENSE-DPark.txt rename to licenses-binary/LICENSE-join.txt diff --git a/licenses-binary/LICENSE-jquery.txt b/licenses-binary/LICENSE-jquery.txt new file mode 100644 index 0000000000000..45930542204fb --- /dev/null +++ b/licenses-binary/LICENSE-jquery.txt @@ -0,0 +1,20 @@ +Copyright JS Foundation and other contributors, https://js.foundation/ + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/licenses-binary/LICENSE-json-formatter.txt b/licenses-binary/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses-binary/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses-binary/LICENSE-jtransforms.html b/licenses-binary/LICENSE-jtransforms.html new file mode 100644 index 0000000000000..351c17412357b --- /dev/null +++ b/licenses-binary/LICENSE-jtransforms.html @@ -0,0 +1,388 @@ + + +Mozilla Public License version 1.1 + + + + +

      Mozilla Public License Version 1.1

      +

      1. Definitions.

      +
      +
      1.0.1. "Commercial Use" +
      means distribution or otherwise making the Covered Code available to a third party. +
      1.1. "Contributor" +
      means each entity that creates or contributes to the creation of Modifications. +
      1.2. "Contributor Version" +
      means the combination of the Original Code, prior Modifications used by a Contributor, + and the Modifications made by that particular Contributor. +
      1.3. "Covered Code" +
      means the Original Code or Modifications or the combination of the Original Code and + Modifications, in each case including portions thereof. +
      1.4. "Electronic Distribution Mechanism" +
      means a mechanism generally accepted in the software development community for the + electronic transfer of data. +
      1.5. "Executable" +
      means Covered Code in any form other than Source Code. +
      1.6. "Initial Developer" +
      means the individual or entity identified as the Initial Developer in the Source Code + notice required by Exhibit A. +
      1.7. "Larger Work" +
      means a work which combines Covered Code or portions thereof with code not governed + by the terms of this License. +
      1.8. "License" +
      means this document. +
      1.8.1. "Licensable" +
      means having the right to grant, to the maximum extent possible, whether at the + time of the initial grant or subsequently acquired, any and all of the rights + conveyed herein. +
      1.9. "Modifications" +
      +

      means any addition to or deletion from the substance or structure of either the + Original Code or any previous Modifications. When Covered Code is released as a + series of files, a Modification is: +

        +
      1. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +
      2. Any new file that contains any part of the Original Code or + previous Modifications. +
      +
      1.10. "Original Code" +
      means Source Code of computer software code which is described in the Source Code + notice required by Exhibit A as Original Code, and which, + at the time of its release under this License is not already Covered Code governed + by this License. +
      1.10.1. "Patent Claims" +
      means any patent claim(s), now owned or hereafter acquired, including without + limitation, method, process, and apparatus claims, in any patent Licensable by + grantor. +
      1.11. "Source Code" +
      means the preferred form of the Covered Code for making modifications to it, + including all modules it contains, plus any associated interface definition files, + scripts used to control compilation and installation of an Executable, or source + code differential comparisons against either the Original Code or another well known, + available Covered Code of the Contributor's choice. The Source Code can be in a + compressed or archival form, provided the appropriate decompression or de-archiving + software is widely available for no charge. +
      1.12. "You" (or "Your") +
      means an individual or a legal entity exercising rights under, and complying with + all of the terms of, this License or a future version of this License issued under + Section 6.1. For legal entities, "You" includes any entity + which controls, is controlled by, or is under common control with You. For purposes of + this definition, "control" means (a) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or otherwise, or (b) + ownership of more than fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. +
      +

      2. Source Code License.

      +

      2.1. The Initial Developer Grant.

      +

      The Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive + license, subject to third party intellectual property claims: +

        +
      1. under intellectual property rights (other than patent or + trademark) Licensable by Initial Developer to use, reproduce, modify, display, perform, + sublicense and distribute the Original Code (or portions thereof) with or without + Modifications, and/or as part of a Larger Work; and +
      2. under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for sale, and/or + otherwise dispose of the Original Code (or portions thereof). +
      3. the licenses granted in this Section 2.1 + (a) and (b) are effective on + the date Initial Developer first distributes Original Code under the terms of this + License. +
      4. Notwithstanding Section 2.1 (b) + above, no patent license is granted: 1) for code that You delete from the Original Code; + 2) separate from the Original Code; or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original Code with other + software or devices. +
      +

      2.2. Contributor Grant.

      +

      Subject to third party intellectual property claims, each Contributor hereby grants You + a world-wide, royalty-free, non-exclusive license +

        +
      1. under intellectual property rights (other than patent or trademark) + Licensable by Contributor, to use, reproduce, modify, display, perform, sublicense and + distribute the Modifications created by such Contributor (or portions thereof) either on + an unmodified basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +
      2. under Patent Claims infringed by the making, using, or selling of + Modifications made by that Contributor either alone and/or in combination with its + Contributor Version (or portions of such combination), to make, use, sell, offer for + sale, have made, and/or otherwise dispose of: 1) Modifications made by that Contributor + (or portions thereof); and 2) the combination of Modifications made by that Contributor + with its Contributor Version (or portions of such combination). +
      3. the licenses granted in Sections 2.2 + (a) and 2.2 (b) are effective + on the date Contributor first makes Commercial Use of the Covered Code. +
      4. Notwithstanding Section 2.2 (b) + above, no patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2) separate from the Contributor Version; 3) for infringements + caused by: i) third party modifications of Contributor Version or ii) the combination of + Modifications made by that Contributor with other software (except as part of the + Contributor Version) or other devices; or 4) under Patent Claims infringed by Covered Code + in the absence of Modifications made by that Contributor. +
      +

      3. Distribution Obligations.

      +

      3.1. Application of License.

      +

      The Modifications which You create or to which You contribute are governed by the terms + of this License, including without limitation Section 2.2. The + Source Code version of Covered Code may be distributed only under the terms of this License + or a future version of this License released under Section 6.1, + and You must include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version that alters or + restricts the applicable version of this License or the recipients' rights hereunder. + However, You may include an additional document offering the additional rights described in + Section 3.5. +

      3.2. Availability of Source Code.

      +

      Any Modification which You create or to which You contribute must be made available in + Source Code form under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom you made an + Executable version available; and if made available via Electronic Distribution Mechanism, + must remain available for at least twelve (12) months after the date it initially became + available, or at least six (6) months after a subsequent version of that particular + Modification has been made available to such recipients. You are responsible for ensuring + that the Source Code version remains available even if the Electronic Distribution + Mechanism is maintained by a third party. +

      3.3. Description of Modifications.

      +

      You must cause all Covered Code to which You contribute to contain a file documenting the + changes You made to create that Covered Code and the date of any change. You must include a + prominent statement that the Modification is derived, directly or indirectly, from Original + Code provided by the Initial Developer and including the name of the Initial Developer in + (a) the Source Code, and (b) in any notice in an Executable version or related documentation + in which You describe the origin or ownership of the Covered Code. +

      3.4. Intellectual Property Matters

      +

      (a) Third Party Claims

      +

      If Contributor has knowledge that a license under a third party's intellectual property + rights is required to exercise the rights granted by such Contributor under Sections + 2.1 or 2.2, Contributor must include a + text file with the Source Code distribution titled "LEGAL" which describes the claim and the + party making the claim in sufficient detail that a recipient will know whom to contact. If + Contributor obtains such knowledge after the Modification is made available as described in + Section 3.2, Contributor shall promptly modify the LEGAL file in + all copies Contributor makes available thereafter and shall take other steps (such as + notifying appropriate mailing lists or newsgroups) reasonably calculated to inform those who + received the Covered Code that new knowledge has been obtained. +

      (b) Contributor APIs

      +

      If Contributor's Modifications include an application programming interface and Contributor + has knowledge of patent licenses which are reasonably necessary to implement that + API, Contributor must also include this information in the + LEGAL file. +

      (c) Representations.

      +

      Contributor represents that, except as disclosed pursuant to Section 3.4 + (a) above, Contributor believes that Contributor's Modifications + are Contributor's original creation(s) and/or Contributor has sufficient rights to grant the + rights conveyed by this License. +

      3.5. Required Notices.

      +

      You must duplicate the notice in Exhibit A in each file of the + Source Code. If it is not possible to put such notice in a particular Source Code file due to + its structure, then You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice. If You created one or more + Modification(s) You may add your name as a Contributor to the notice described in + Exhibit A. You must also duplicate this License in any documentation + for the Source Code where You describe recipients' rights or ownership rights relating to + Covered Code. You may choose to offer, and to charge a fee for, warranty, support, indemnity + or liability obligations to one or more recipients of Covered Code. However, You may do so + only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You + must make it absolutely clear than any such warranty, support, indemnity or liability + obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer + and every Contributor for any liability incurred by the Initial Developer or such Contributor + as a result of warranty, support, indemnity or liability terms You offer. +

      3.6. Distribution of Executable Versions.

      +

      You may distribute Covered Code in Executable form only if the requirements of Sections + 3.1, 3.2, + 3.3, 3.4 and + 3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available under the terms + of this License, including a description of how and where You have fulfilled the obligations + of Section 3.2. The notice must be conspicuously included in any + notice in an Executable version, related documentation or collateral in which You describe + recipients' rights relating to the Covered Code. You may distribute the Executable version of + Covered Code or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the terms of this + License and that the license for the Executable version does not attempt to limit or alter the + recipient's rights in the Source Code version from the rights set forth in this License. If + You distribute the Executable version under a different license You must make it absolutely + clear that any terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the Initial Developer and + every Contributor for any liability incurred by the Initial Developer or such Contributor as + a result of any such terms You offer. +

      3.7. Larger Works.

      +

      You may create a Larger Work by combining Covered Code with other code not governed by the + terms of this License and distribute the Larger Work as a single product. In such a case, + You must make sure the requirements of this License are fulfilled for the Covered Code. +

      4. Inability to Comply Due to Statute or Regulation.

      +

      If it is impossible for You to comply with any of the terms of this License with respect to + some or all of the Covered Code due to statute, judicial order, or regulation then You must: + (a) comply with the terms of this License to the maximum extent possible; and (b) describe + the limitations and the code they affect. Such description must be included in the + LEGAL file described in Section + 3.4 and must be included with all distributions of the Source Code. + Except to the extent prohibited by statute or regulation, such description must be + sufficiently detailed for a recipient of ordinary skill to be able to understand it. +

      5. Application of this License.

      +

      This License applies to code to which the Initial Developer has attached the notice in + Exhibit A and to related Covered Code. +

      6. Versions of the License.

      +

      6.1. New Versions

      +

      Netscape Communications Corporation ("Netscape") may publish revised and/or new versions + of the License from time to time. Each version will be given a distinguishing version number. +

      6.2. Effect of New Versions

      +

      Once Covered Code has been published under a particular version of the License, You may + always continue to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License published by Netscape. + No one other than Netscape has the right to modify the terms applicable to Covered Code + created under this License. +

      6.3. Derivative Works

      +

      If You create or use a modified version of this License (which you may only do in order to + apply it to code which is not already Covered Code governed by this License), You must (a) + rename Your license so that the phrases "Mozilla", "MOZILLAPL", "MOZPL", "Netscape", "MPL", + "NPL" or any confusingly similar phrase do not appear in your license (except to note that + your license differs from this License) and (b) otherwise make it clear that Your version of + the license contains terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or Contributor in the + notice described in Exhibit A shall not of themselves be deemed to + be modifications of this License.) +

      7. Disclaimer of warranty

      +

      Covered code is provided under this license on an "as is" + basis, without warranty of any kind, either expressed or implied, including, without + limitation, warranties that the covered code is free of defects, merchantable, fit for a + particular purpose or non-infringing. The entire risk as to the quality and performance of + the covered code is with you. Should any covered code prove defective in any respect, you + (not the initial developer or any other contributor) assume the cost of any necessary + servicing, repair or correction. This disclaimer of warranty constitutes an essential part + of this license. No use of any covered code is authorized hereunder except under this + disclaimer. +

      8. Termination

      +

      8.1. This License and the rights granted hereunder will terminate + automatically if You fail to comply with terms herein and fail to cure such breach + within 30 days of becoming aware of the breach. All sublicenses to the Covered Code which + are properly granted shall survive any termination of this License. Provisions which, by + their nature, must remain in effect beyond the termination of this License shall survive. +

      8.2. If You initiate litigation by asserting a patent infringement + claim (excluding declatory judgment actions) against Initial Developer or a Contributor + (the Initial Developer or Contributor against whom You file such action is referred to + as "Participant") alleging that: +

        +
      1. such Participant's Contributor Version directly or indirectly + infringes any patent, then any and all rights granted by such Participant to You under + Sections 2.1 and/or 2.2 of this + License shall, upon 60 days notice from Participant terminate prospectively, unless if + within 60 days after receipt of notice You either: (i) agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future use of + Modifications made by such Participant, or (ii) withdraw Your litigation claim with + respect to the Contributor Version against such Participant. If within 60 days of + notice, a reasonable royalty and payment arrangement are not mutually agreed upon in + writing by the parties or the litigation claim is not withdrawn, the rights granted by + Participant to You under Sections 2.1 and/or + 2.2 automatically terminate at the expiration of the 60 day + notice period specified above. +
      2. any software, hardware, or device, other than such Participant's + Contributor Version, directly or indirectly infringes any patent, then any rights + granted to You by such Participant under Sections 2.1(b) + and 2.2(b) are revoked effective as of the date You first + made, used, sold, distributed, or had made, Modifications made by that Participant. +
      +

      8.3. If You assert a patent infringement claim against Participant + alleging that such Participant's Contributor Version directly or indirectly infringes + any patent where such claim is resolved (such as by license or settlement) prior to the + initiation of patent infringement litigation, then the reasonable value of the licenses + granted by such Participant under Sections 2.1 or + 2.2 shall be taken into account in determining the amount or + value of any payment or license. +

      8.4. In the event of termination under Sections + 8.1 or 8.2 above, all end user + license agreements (excluding distributors and resellers) which have been validly + granted by You or any distributor hereunder prior to termination shall survive + termination. +

      9. Limitation of liability

      +

      Under no circumstances and under no legal theory, whether + tort (including negligence), contract, or otherwise, shall you, the initial developer, + any other contributor, or any distributor of covered code, or any supplier of any of + such parties, be liable to any person for any indirect, special, incidental, or + consequential damages of any character including, without limitation, damages for loss + of goodwill, work stoppage, computer failure or malfunction, or any and all other + commercial damages or losses, even if such party shall have been informed of the + possibility of such damages. This limitation of liability shall not apply to liability + for death or personal injury resulting from such party's negligence to the extent + applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion + or limitation of incidental or consequential damages, so this exclusion and limitation + may not apply to you. +

      10. U.S. government end users

      +

      The Covered Code is a "commercial item," as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of + "commercial computer software" and "commercial computer software documentation," as such + terms are used in 48 C.F.R. 12.212 (Sept. + 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein. +

      11. Miscellaneous

      +

      This License represents the complete agreement concerning subject matter hereof. If + any provision of this License is held to be unenforceable, such provision shall be + reformed only to the extent necessary to make it enforceable. This License shall be + governed by California law provisions (except to the extent applicable law, if any, + provides otherwise), excluding its conflict-of-law provisions. With respect to + disputes in which at least one party is a citizen of, or an entity chartered or + registered to do business in the United States of America, any litigation relating to + this License shall be subject to the jurisdiction of the Federal Courts of the + Northern District of California, with venue lying in Santa Clara County, California, + with the losing party responsible for costs, including without limitation, court + costs and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is expressly + excluded. Any law or regulation which provides that the language of a contract + shall be construed against the drafter shall not apply to this License. +

      12. Responsibility for claims

      +

      As between Initial Developer and the Contributors, each party is responsible for + claims and damages arising, directly or indirectly, out of its utilization of rights + under this License and You agree to work with Initial Developer and Contributors to + distribute such responsibility on an equitable basis. Nothing herein is intended or + shall be deemed to constitute any admission of liability. +

      13. Multiple-licensed code

      +

      Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed". "Multiple-Licensed" means that the Initial Developer permits + you to utilize portions of the Covered Code under Your choice of the MPL + or the alternative licenses, if any, specified by the Initial Developer in the file + described in Exhibit A. +

      Exhibit A - Mozilla Public License.

      +
      "The contents of this file are subject to the Mozilla Public License
      +Version 1.1 (the "License"); you may not use this file except in
      +compliance with the License. You may obtain a copy of the License at
      +http://www.mozilla.org/MPL/
      +
      +Software distributed under the License is distributed on an "AS IS"
      +basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
      +License for the specific language governing rights and limitations
      +under the License.
      +
      +The Original Code is JTransforms.
      +
      +The Initial Developer of the Original Code is
      +Piotr Wendykier, Emory University.
      +Portions created by the Initial Developer are Copyright (C) 2007-2009
      +the Initial Developer. All Rights Reserved.
      +
      +Alternatively, the contents of this file may be used under the terms of
      +either the GNU General Public License Version 2 or later (the "GPL"), or
      +the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
      +in which case the provisions of the GPL or the LGPL are applicable instead
      +of those above. If you wish to allow use of your version of this file only
      +under the terms of either the GPL or the LGPL, and not to allow others to
      +use your version of this file under the terms of the MPL, indicate your
      +decision by deleting the provisions above and replace them with the notice
      +and other provisions required by the GPL or the LGPL. If you do not delete
      +the provisions above, a recipient may use your version of this file under
      +the terms of any one of the MPL, the GPL or the LGPL.
      +

      NOTE: The text of this Exhibit A may differ slightly from the text of + the notices in the Source Code files of the Original Code. You should + use the text of this Exhibit A rather than the text found in the + Original Code Source Code for Your Modifications. + +

      \ No newline at end of file diff --git a/licenses/LICENSE-kryo.txt b/licenses-binary/LICENSE-kryo.txt similarity index 100% rename from licenses/LICENSE-kryo.txt rename to licenses-binary/LICENSE-kryo.txt diff --git a/licenses-binary/LICENSE-leveldbjni.txt b/licenses-binary/LICENSE-leveldbjni.txt new file mode 100644 index 0000000000000..b4dabb9174c6d --- /dev/null +++ b/licenses-binary/LICENSE-leveldbjni.txt @@ -0,0 +1,27 @@ +Copyright (c) 2011 FuseSource Corp. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of FuseSource Corp. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-machinist.txt b/licenses-binary/LICENSE-machinist.txt new file mode 100644 index 0000000000000..68cc3a3e3a9c4 --- /dev/null +++ b/licenses-binary/LICENSE-machinist.txt @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Erik Osheim, Tom Switzer + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/licenses-binary/LICENSE-matchMedia-polyfill.txt b/licenses-binary/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses-binary/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-minlog.txt b/licenses-binary/LICENSE-minlog.txt similarity index 100% rename from licenses/LICENSE-minlog.txt rename to licenses-binary/LICENSE-minlog.txt diff --git a/licenses-binary/LICENSE-modernizr.txt b/licenses-binary/LICENSE-modernizr.txt new file mode 100644 index 0000000000000..2bf24b9b9f848 --- /dev/null +++ b/licenses-binary/LICENSE-modernizr.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-netlib.txt b/licenses-binary/LICENSE-netlib.txt similarity index 100% rename from licenses/LICENSE-netlib.txt rename to licenses-binary/LICENSE-netlib.txt diff --git a/licenses/LICENSE-paranamer.txt b/licenses-binary/LICENSE-paranamer.txt similarity index 100% rename from licenses/LICENSE-paranamer.txt rename to licenses-binary/LICENSE-paranamer.txt diff --git a/licenses/LICENSE-jpmml-model.txt b/licenses-binary/LICENSE-pmml-model.txt similarity index 100% rename from licenses/LICENSE-jpmml-model.txt rename to licenses-binary/LICENSE-pmml-model.txt diff --git a/licenses/LICENSE-protobuf.txt b/licenses-binary/LICENSE-protobuf.txt similarity index 100% rename from licenses/LICENSE-protobuf.txt rename to licenses-binary/LICENSE-protobuf.txt diff --git a/licenses-binary/LICENSE-py4j.txt b/licenses-binary/LICENSE-py4j.txt new file mode 100644 index 0000000000000..70af3e69ed67a --- /dev/null +++ b/licenses-binary/LICENSE-py4j.txt @@ -0,0 +1,27 @@ +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. 
+ +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + diff --git a/licenses/LICENSE-pyrolite.txt b/licenses-binary/LICENSE-pyrolite.txt similarity index 100% rename from licenses/LICENSE-pyrolite.txt rename to licenses-binary/LICENSE-pyrolite.txt diff --git a/licenses/LICENSE-reflectasm.txt b/licenses-binary/LICENSE-reflectasm.txt similarity index 100% rename from licenses/LICENSE-reflectasm.txt rename to licenses-binary/LICENSE-reflectasm.txt diff --git a/licenses-binary/LICENSE-respond.txt b/licenses-binary/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses-binary/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-sbt-launch-lib.txt b/licenses-binary/LICENSE-sbt-launch-lib.txt new file mode 100644 index 0000000000000..3b9156baaab78 --- /dev/null +++ b/licenses-binary/LICENSE-sbt-launch-lib.txt @@ -0,0 +1,26 @@ +// Generated from http://www.opensource.org/licenses/bsd-license.php +Copyright (c) 2011, Paul Phillips. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. 
+ * Neither the name of the author nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-scala.txt b/licenses-binary/LICENSE-scala.txt similarity index 100% rename from licenses/LICENSE-scala.txt rename to licenses-binary/LICENSE-scala.txt diff --git a/licenses-binary/LICENSE-scopt.txt b/licenses-binary/LICENSE-scopt.txt new file mode 100644 index 0000000000000..e92e9b592fba0 --- /dev/null +++ b/licenses-binary/LICENSE-scopt.txt @@ -0,0 +1,9 @@ +This project is licensed under the MIT license. + +Copyright (c) scopt contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/licenses/LICENSE-slf4j.txt b/licenses-binary/LICENSE-slf4j.txt similarity index 100% rename from licenses/LICENSE-slf4j.txt rename to licenses-binary/LICENSE-slf4j.txt diff --git a/licenses-binary/LICENSE-sorttable.js.txt b/licenses-binary/LICENSE-sorttable.js.txt new file mode 100644 index 0000000000000..b31a5b206bf40 --- /dev/null +++ b/licenses-binary/LICENSE-sorttable.js.txt @@ -0,0 +1,16 @@ +Copyright (c) 1997-2007 Stuart Langridge + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/licenses/LICENSE-spire.txt b/licenses-binary/LICENSE-spire.txt similarity index 100% rename from licenses/LICENSE-spire.txt rename to licenses-binary/LICENSE-spire.txt diff --git a/licenses-binary/LICENSE-vis.txt b/licenses-binary/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses-binary/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/licenses/LICENSE-xmlenc.txt b/licenses-binary/LICENSE-xmlenc.txt similarity index 100% rename from licenses/LICENSE-xmlenc.txt rename to licenses-binary/LICENSE-xmlenc.txt diff --git a/licenses/LICENSE-zstd-jni.txt b/licenses-binary/LICENSE-zstd-jni.txt similarity index 100% rename from licenses/LICENSE-zstd-jni.txt rename to licenses-binary/LICENSE-zstd-jni.txt diff --git a/licenses/LICENSE-zstd.txt b/licenses-binary/LICENSE-zstd.txt similarity index 100% rename from licenses/LICENSE-zstd.txt rename to licenses-binary/LICENSE-zstd.txt diff --git a/licenses/LICENSE-CC0.txt b/licenses/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. 
+ +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). 
Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-SnapTree.txt b/licenses/LICENSE-SnapTree.txt deleted file mode 100644 index a538825d89ec5..0000000000000 --- a/licenses/LICENSE-SnapTree.txt +++ /dev/null @@ -1,35 +0,0 @@ -SNAPTREE LICENSE - -Copyright (c) 2009-2012 Stanford University, unless otherwise specified. -All rights reserved. - -This software was developed by the Pervasive Parallelism Laboratory of -Stanford University, California, USA. 
- -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of Stanford University nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. diff --git a/licenses/LICENSE-bootstrap.txt b/licenses/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses/LICENSE-boto.txt b/licenses/LICENSE-boto.txt deleted file mode 100644 index 7bba0cd9e10a4..0000000000000 --- a/licenses/LICENSE-boto.txt +++ /dev/null @@ -1,20 +0,0 @@ -Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, dis- -tribute, sublicense, and/or sell copies of the Software, and to permit -persons to whom the Software is furnished to do so, subject to the fol- -lowing conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- -ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. 
\ No newline at end of file diff --git a/licenses/LICENSE-datatables.txt b/licenses/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-graphlib-dot.txt b/licenses/LICENSE-graphlib-dot.txt index c9e18cd562423..4864fe05e9803 100644 --- a/licenses/LICENSE-graphlib-dot.txt +++ b/licenses/LICENSE-graphlib-dot.txt @@ -1,4 +1,4 @@ -Copyright (c) 2012-2013 Chris Pettitt +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses/LICENSE-jbcrypt.txt b/licenses/LICENSE-jbcrypt.txt deleted file mode 100644 index d332534c06356..0000000000000 --- a/licenses/LICENSE-jbcrypt.txt +++ /dev/null @@ -1,17 +0,0 @@ -jBCrypt is subject to the following license: - -/* - * Copyright (c) 2006 Damien Miller - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ diff --git a/licenses/LICENSE-join.txt b/licenses/LICENSE-join.txt new file mode 100644 index 0000000000000..1d916090e4ea0 --- /dev/null +++ b/licenses/LICENSE-join.txt @@ -0,0 +1,30 @@ +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. 
+ + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jquery.txt b/licenses/LICENSE-jquery.txt index e1dd696d3b6cc..45930542204fb 100644 --- a/licenses/LICENSE-jquery.txt +++ b/licenses/LICENSE-jquery.txt @@ -1,9 +1,20 @@ -The MIT License (MIT) +Copyright JS Foundation and other contributors, https://js.foundation/ -Copyright (c) +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/licenses/LICENSE-json-formatter.txt b/licenses/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-matchMedia-polyfill.txt b/licenses/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-postgresql.txt b/licenses/LICENSE-postgresql.txt deleted file mode 100644 index 515bf9af4d432..0000000000000 --- a/licenses/LICENSE-postgresql.txt +++ /dev/null @@ -1,24 +0,0 @@ -PostgreSQL Database Management System -(formerly known as Postgres, then as Postgres95) - -Portions Copyright (c) 1996-2010, PostgreSQL Global Development Group - -Portions Copyright (c) 1994, The Regents of the University of California - -Permission to use, copy, modify, and distribute this software and its -documentation for any purpose, without fee, and without a written agreement -is hereby granted, provided that the above copyright notice and this -paragraph and the following two paragraphs appear in all copies. - -IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR -DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING -LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS -DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - -THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, -INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY -AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS -ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO -PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. - diff --git a/licenses/LICENSE-respond.txt b/licenses/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-scalacheck.txt b/licenses/LICENSE-scalacheck.txt deleted file mode 100644 index cb8f97842f4c4..0000000000000 --- a/licenses/LICENSE-scalacheck.txt +++ /dev/null @@ -1,32 +0,0 @@ -ScalaCheck LICENSE - -Copyright (c) 2007-2015, Rickard Nilsson -All rights reserved. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of the author nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-vis.txt b/licenses/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. 
\ No newline at end of file diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister index 5e5484fd8784d..f14431d50feec 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -1,2 +1,4 @@ org.apache.spark.ml.regression.InternalLinearRegressionModelWriter -org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter +org.apache.spark.ml.clustering.InternalKMeansModelWriter +org.apache.spark.ml.clustering.PMMLKMeansModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 771cd4fe91dcf..8a57bfc029d14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -96,10 +97,14 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.6.0") override def setSeed(value: Long): this.type = set(seed, value) - override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { + override protected def train( + dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + instr.logNumClasses(numClasses) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -110,29 +115,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(params: _*) + instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeClassificationModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeClassificationModel] } /** (private[ml]) Train a decision tree on an RDD */ private[ml] def train(data: RDD[LabeledPoint], - oldStrategy: OldStrategy): DecisionTreeClassificationModel = { - val instr = Instrumentation.create(this, data) - instr.logParams(params: _*) + oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(data) + instr.logParams(this, 
maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeClassificationModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeClassificationModel] } /** (private[ml]) Create a Strategy instance to use with the old API. */ @@ -279,7 +282,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) val model = new DecisionTreeClassificationModel(metadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c0255103bc313..33acd9914073f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -31,9 +31,9 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -146,12 +146,22 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - override protected def train(dataset: Dataset[_]): GBTClassificationModel = { + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + + override protected def train( + dataset: Dataset[_]): GBTClassificationModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. 
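Throughout the classifier changes above, the manual `Instrumentation.create(...)` / `instr.logSuccess(m)` bookkeeping is replaced by an `instrumented { instr => ... }` wrapper. As a rough illustration only (a sketch of the loan pattern being adopted, not Spark's actual internals; all names below are assumptions), such a helper can look like this:

```scala
// Simplified sketch of a loan-pattern instrumentation helper.
// Names and behaviour are assumptions for illustration, not Spark's real implementation.
object InstrumentationSketch {
  final class Instr {
    def logNamedValue(name: String, value: String): Unit = println(s"$name=$value")
    def logSuccess(): Unit = println("training succeeded")
    def logFailure(e: Throwable): Unit = println(s"training failed: ${e.getMessage}")
  }

  // The caller "borrows" an Instr for the duration of the body; success and
  // failure logging happen in one place instead of in every train() method.
  def instrumented[T](body: Instr => T): T = {
    val instr = new Instr
    try {
      val result = body(instr)
      instr.logSuccess()
      result
    } catch {
      case e: Throwable =>
        instr.logFailure(e)
        throw e
    }
  }
}
```

This is why the `val m = ...; instr.logSuccess(m); m` tail disappears from each `train` method in the hunks above and below: the wrapper handles it once.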
- val oldDataset: RDD[LabeledPoint] = + val convert2LabeledPoint = (dataset: Dataset[_]) => { dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + @@ -159,7 +169,18 @@ class GBTClassifier @Since("1.4.0") ( s" GBTClassifier currently only supports binary classification.") LabeledPoint(label, features) } - val numFeatures = oldDataset.first().features.size + } + + val (trainDataset, validationDataset) = if (withValidation) { + ( + convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))), + convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (convert2LabeledPoint(dataset), null) + } + + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -169,18 +190,23 @@ class GBTClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, + validationIndicatorCol) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) - val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) - instr.logSuccess(m) - m + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) + } + + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.1") @@ -334,6 +360,21 @@ class GBTClassificationModel private[ml]( // hard coded loss, which is not meant to be changed in the model private val loss = getOldLossType + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. 
+ */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss, + OldAlgo.Classification + ) + } + @Since("2.0.0") override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } @@ -379,14 +420,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 8f950cd28c3aa..1b5c02fc9a576 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -162,7 +163,7 @@ class LinearSVC @Since("2.2.0") ( @Since("2.2.0") override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra) - override protected def train(dataset: Dataset[_]): LinearSVCModel = { + override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr => val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { @@ -170,8 +171,9 @@ class LinearSVC @Since("2.2.0") ( Instance(label, weight, features) } - val instr = Instrumentation.create(this, instances) - instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth) val (summarizer, labelSummarizer) = { @@ -187,6 +189,9 @@ class LinearSVC @Since("2.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNumExamples(summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -209,7 +214,7 @@ class LinearSVC @Since("2.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." 
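The GBTClassifier hunks above add a `validationIndicatorCol` setter and an `evaluateEachIteration` method on the fitted model. A minimal usage sketch, assuming hypothetical DataFrames `train` (with a boolean "isVal" column marking validation rows) and `test`, both with "label" and "features" columns:

```scala
import org.apache.spark.ml.classification.GBTClassifier

// `train` and `test` are assumed DataFrames with "label"/"features" columns;
// "isVal" is a hypothetical boolean column marking held-out validation rows.
val gbt = new GBTClassifier()
  .setMaxIter(50)
  .setValidationIndicatorCol("isVal")

val model = gbt.fit(train)

// Loss of the ensemble after each boosting iteration, computed on `test`.
val lossPerIter: Array[Double] = model.evaluateEachIteration(test)
```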
- logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -246,7 +251,7 @@ class LinearSVC @Since("2.2.0") ( bcFeaturesStd.destroy(blocking = false) if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -273,9 +278,7 @@ class LinearSVC @Since("2.2.0") ( (Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result()) } - val model = copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) - instr.logSuccess(model) - model + copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) } } @@ -377,7 +380,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { val Row(coefficients: Vector, intercept: Double) = data.select("coefficients", "intercept").head() val model = new LinearSVCModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ee4b01058c75c..6f0804f0c8e4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -35,6 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer @@ -490,7 +491,7 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train( dataset: Dataset[_], - handlePersistence: Boolean): LogisticRegressionModel = { + handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr => val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { @@ -500,8 +501,9 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, instances) - instr.logParams(regParam, elasticNetParam, standardization, threshold, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) val (summarizer, labelSummarizer) = { @@ -517,7 +519,7 @@ class LogisticRegression @Since("1.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } - instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNumExamples(summarizer.count) instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) @@ -816,7 +818,7 @@ class LogisticRegression @Since("1.2.0") ( if (state == null) { val msg = s"${optimizer.getClass.getName} failed." 
- logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -905,8 +907,6 @@ class LogisticRegression @Since("1.2.0") ( objectiveHistory) } model.setSummary(Some(logRegSummary)) - instr.logSuccess(model) - model } @Since("1.4.0") @@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] ( */ @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + + override def toString: String = { + s"LogisticRegressionModel: " + + s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" + } } @@ -1270,7 +1275,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { numClasses, isMultinomial) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -1479,7 +1484,7 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * Convenient method for casting to binary logistic regression summary. - * This method will throws an Exception if the summary is not a binary summary. + * This method will throw an Exception if the summary is not a binary summary. */ @Since("2.3.0") def asBinary: BinaryLogisticRegressionSummary = this match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index af2e4699924e5..4feddce1d9f2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -23,12 +23,13 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.OneHotEncoderModel import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.Dataset +import org.apache.spark.ml.util.Instrumentation.instrumented +import org.apache.spark.sql.{Dataset, Row} /** Params for Multilayer Perceptron. */ private[classification] trait MultilayerPerceptronParams extends ProbabilisticClassifierParams @@ -102,36 +103,6 @@ private[classification] trait MultilayerPerceptronParams extends ProbabilisticCl solver -> LBFGS, stepSize -> 0.03) } -/** Label to vector converter. */ -private object LabelConverter { - // TODO: Use OneHotEncoder instead - /** - * Encodes a label as a vector. - * Returns a vector of given length with zeroes at all positions - * and value 1.0 at the position that corresponds to the label. - * - * @param labeledPoint labeled point - * @param labelCount total number of labels - * @return pair of features and vector encoding of a label - */ - def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { - val output = Array.fill(labelCount)(0.0) - output(labeledPoint.label.toInt) = 1.0 - (labeledPoint.features, Vectors.dense(output)) - } - - /** - * Converts a vector to a label. - * Returns the position of the maximal element of a vector. - * - * @param output label encoded with a vector - * @return label - */ - def decodeLabel(output: Vector): Double = { - output.argmax.toDouble - } -} - /** * Classifier trainer based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. 
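The LogisticRegression diff above also adds a `toString` override to the fitted model, so printing a model now reports its uid, class count, and feature count. A small sketch, assuming a hypothetical `training` DataFrame with "label" and "features" columns (the uid shown is illustrative):

```scala
import org.apache.spark.ml.classification.LogisticRegression

// `training` is an assumed DataFrame with "label" and "features" columns.
val lrModel = new LogisticRegression().setMaxIter(10).fit(training)
println(lrModel)
// LogisticRegressionModel: uid = logreg_a1b2c3d4e5f6, numClasses = 2, numFeatures = 692
```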
@@ -230,9 +201,11 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = { - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, layers, maxIter, tol, + override protected def train( + dataset: Dataset[_]): MultilayerPerceptronClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, layers, maxIter, tol, blockSize, solver, stepSize, seed) val myLayers = $(layers) @@ -240,8 +213,18 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( instr.logNumClasses(labels) instr.logNumFeatures(myLayers.head) - val lpData = extractLabeledPoints(dataset) - val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) + // One-hot encoding for labels using OneHotEncoderModel. + // As we already know the length of encoding, we skip fitting and directly create + // the model. + val encodedLabelCol = "_encoded" + $(labelCol) + val encodeModel = new OneHotEncoderModel(uid, Array(labels)) + .setInputCols(Array($(labelCol))) + .setOutputCols(Array(encodedLabelCol)) + .setDropLast(false) + val encodedDataset = encodeModel.transform(dataset) + val data = encodedDataset.select($(featuresCol), encodedLabelCol).rdd.map { + case Row(features: Vector, encodedLabel: Vector) => (features, encodedLabel) + } val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, softmaxOnTop = true) val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) if (isDefined(initialWeights)) { @@ -264,10 +247,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( } trainer.setStackSize($(blockSize)) val mlpModel = trainer.train(data) - val model = new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) - - instr.logSuccess(model) - model + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) } } @@ -323,7 +303,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( * This internal method is used to implement `transform()` and output [[predictionCol]]. 
*/ override def predict(features: Vector): Double = { - LabelConverter.decodeLabel(mlpModel.predict(features)) + mlpModel.predict(features).argmax.toDouble } @Since("1.5.0") @@ -388,7 +368,7 @@ object MultilayerPerceptronClassificationModel val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0293e03d47435..51495c1a74e69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit} @@ -125,9 +126,12 @@ class NaiveBayes @Since("1.5.0") ( */ private[spark] def trainWithLabelCheck( dataset: Dataset[_], - positiveLabel: Boolean): NaiveBayesModel = { + positiveLabel: Boolean): NaiveBayesModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) if (positiveLabel && isDefined(thresholds)) { val numClasses = getNumClasses(dataset) + instr.logNumClasses(numClasses) require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") @@ -146,8 +150,7 @@ class NaiveBayes @Since("1.5.0") ( } } - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, probabilityCol, modelType, smoothing, thresholds) val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size @@ -159,19 +162,21 @@ class NaiveBayes @Since("1.5.0") ( // TODO: similar to reduceByKeyLocally to save one stage. 
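In the MultilayerPerceptronClassifier changes above, the hand-rolled LabelConverter is dropped: labels are one-hot encoded with OneHotEncoderModel before training, and predictions are decoded simply via `argmax` on the network output. A tiny sketch of that encode/decode round trip (illustration only, using public linalg types; `encode` stands in for what the dropLast=false one-hot encoding produces for the label column):

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// One-hot encode a label into a vector of length `numClasses`,
// mirroring what OneHotEncoderModel with dropLast=false emits for the label.
def encode(label: Double, numClasses: Int): Vector = {
  val arr = Array.fill(numClasses)(0.0)
  arr(label.toInt) = 1.0
  Vectors.dense(arr)
}

// Decoding is just the index of the largest network output,
// matching `mlpModel.predict(features).argmax.toDouble` above.
val encoded = encode(2.0, numClasses = 3)   // [0.0, 0.0, 1.0]
val decoded = encoded.argmax.toDouble       // 2.0
```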
val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) - }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( + }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))( seqOp = { - case ((weightSum: Double, featureSum: DenseVector), (weight, features)) => + case ((weightSum, featureSum, count), (weight, features)) => requireValues(features) BLAS.axpy(weight, features, featureSum) - (weightSum + weight, featureSum) + (weightSum + weight, featureSum, count + 1) }, combOp = { - case ((weightSum1, featureSum1), (weightSum2, featureSum2)) => + case ((weightSum1, featureSum1, count1), (weightSum2, featureSum2, count2)) => BLAS.axpy(1.0, featureSum2, featureSum1) - (weightSum1 + weightSum2, featureSum1) + (weightSum1 + weightSum2, featureSum1, count1 + count2) }).collect().sortBy(_._1) + val numSamples = aggregated.map(_._2._3).sum + instr.logNumExamples(numSamples) val numLabels = aggregated.length instr.logNumClasses(numLabels) val numDocuments = aggregated.map(_._2._1).sum @@ -183,7 +188,7 @@ class NaiveBayes @Since("1.5.0") ( val lambda = $(smoothing) val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 - aggregated.foreach { case (label, (n, sumTermFreqs)) => + aggregated.foreach { case (label, (n, sumTermFreqs, _)) => labelArray(i) = label piArray(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = $(modelType) match { @@ -203,9 +208,7 @@ class NaiveBayes @Since("1.5.0") ( val pi = Vectors.dense(piArray) val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) - val model = new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) - instr.logSuccess(model) - model + new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) } @Since("1.5.0") @@ -407,7 +410,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 5348d882cfd67..1835a91775e0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -36,6 +36,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -289,7 +290,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) } val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) - DefaultParamsReader.getAndSetParams(ovrModel, metadata) + metadata.getAndSetParams(ovrModel) ovrModel.set("classifier", classifier) ovrModel } @@ -362,11 +363,12 @@ final class OneVsRest @Since("1.4.0") ( } @Since("2.0.0") - override def fit(dataset: Dataset[_]): OneVsRestModel = { + override def fit(dataset: Dataset[_]): OneVsRestModel = instrumented { instr => transformSchema(dataset.schema) - val 
instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, parallelism) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) // determine number of classes either from metadata if provided, or via computation. @@ -383,7 +385,7 @@ final class OneVsRest @Since("1.4.0") ( getClassifier match { case _: HasWeightCol => true case c => - logWarning(s"weightCol is ignored, as it is not supported by $c now.") + instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.") false } } @@ -440,7 +442,6 @@ final class OneVsRest @Since("1.4.0") ( case attr: Attribute => attr } val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) - instr.logSuccess(model) copyValues(model) } @@ -484,7 +485,7 @@ object OneVsRest extends MLReadable[OneVsRest] { override def load(path: String): OneVsRest = { val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val ovr = new OneVsRest(metadata.uid) - DefaultParamsReader.getAndSetParams(ovr, metadata) + metadata.getAndSetParams(ovr) ovr.setClassifier(classifier) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bb972e9706fc1..94887ac346fec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -115,7 +116,10 @@ class RandomForestClassifier @Since("1.4.0") ( override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { + override protected def train( + dataset: Dataset[_]): RandomForestClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -130,8 +134,7 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, + instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) @@ -140,9 +143,9 @@ class RandomForestClassifier @Since("1.4.0") ( .map(_.asInstanceOf[DecisionTreeClassificationModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestClassificationModel(uid, trees, 
numFeatures, numClasses) - instr.logSuccess(m) - m + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) + new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) } @Since("1.4.1") @@ -319,14 +322,14 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica case (treeMetadata, root) => val tree = new DecisionTreeClassificationModel(treeMetadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f7c422dc0faea..8904193cae94c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -22,17 +22,16 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} -import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -75,7 +74,7 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -113,7 +112,8 @@ class BisectingKMeansModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -132,9 +132,9 @@ class BisectingKMeansModel private[ml] ( */ @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } - parentModel.computeCost(data.map(OldVectors.fromML)) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = 
DatasetUtils.columnToOldVector(dataset, getFeaturesCol) + parentModel.computeCost(data) } @Since("2.0.0") @@ -193,7 +193,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) val model = new BisectingKMeansModel(metadata.uid, mllibModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -258,14 +258,14 @@ class BisectingKMeans @Since("2.0.0") ( def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): BisectingKMeansModel = { + override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr => transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - val instr = Instrumentation.create(this, rdd) - instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, k, maxIter, seed, + minDivisibleClusterSize, distanceMeasure) val bkm = new MLlibBisectingKMeans() .setK($(k)) @@ -273,13 +273,13 @@ class BisectingKMeans @Since("2.0.0") ( .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) .setDistanceMeasure($(distanceMeasure)) - val parentModel = bkm.run(rdd) + val parentModel = bkm.run(rdd, Some(instr)) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) + instr.logNamedValue("clusterSizes", summary.clusterSizes) + instr.logNumFeatures(model.clusterCenters.head.size) model.setSummary(Some(summary)) - instr.logSuccess(model) - model } @Since("2.0.0") @@ -305,6 +305,7 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.1.0") @Experimental @@ -312,4 +313,5 @@ class BisectingKMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala index 44e832b058b62..7da4c43a1abf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.sql.{DataFrame, Row} /** @@ -28,13 +28,15 @@ import org.apache.spark.sql.{DataFrame, Row} * @param predictionCol Name for column of predicted clusters in `predictions`. 
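The hunks above all converge on the same instrumentation shape. A minimal sketch of that shape, assuming only what the diff itself shows (instrumented supplies an Instrumentation instance to the body and takes over the completion logging that the removed instr.logSuccess(model) calls used to do); MyModel and trainImpl are placeholders, not Spark classes:

import org.apache.spark.ml.util.Instrumentation.instrumented

// Schematic only; this lives inside an Estimator[MyModel] subclass in org.apache.spark.ml.
override def fit(dataset: Dataset[_]): MyModel = instrumented { instr =>
  transformSchema(dataset.schema, logging = true)
  instr.logPipelineStage(this)          // replaces Instrumentation.create(this, dataset)
  instr.logDataset(dataset)
  instr.logParams(this, featuresCol, predictionCol, k, maxIter, seed)
  val model = trainImpl(dataset)        // estimator-specific training (hypothetical helper)
  copyValues(model).setParent(this)     // returned value; no explicit logSuccess needed
}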
* @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Experimental class ClusteringSummary private[clustering] ( @transient val predictions: DataFrame, val predictionCol: String, val featuresCol: String, - val k: Int) extends Serializable { + val k: Int, + @Since("2.4.0") val numIter: Int) extends Serializable { /** * Cluster centers of the transformed data. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f19ad7a5a6938..88abc1605d69f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -29,11 +29,12 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix, Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -63,7 +64,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } @@ -109,8 +110,9 @@ class GaussianMixtureModel private[ml] ( transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) - .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + dataset + .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -233,7 +235,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { } val model = new GaussianMixtureModel(metadata.uid, weights, gaussians) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -334,13 +336,14 @@ class GaussianMixture @Since("2.0.0") ( private val numSamples = 5 @Since("2.0.0") - override def fit(dataset: Dataset[_]): GaussianMixtureModel = { + override def fit(dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map { + val instances = dataset + .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() @@ -350,8 +353,9 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the 
covariance" + s" matrix is quadratic in the number of features.") - val instr = Instrumentation.create(this, instances) - instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) instr.logNumFeatures(numFeatures) val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians( @@ -382,6 +386,11 @@ class GaussianMixture @Since("2.0.0") ( bcWeights.destroy(blocking = false) bcGaussians.destroy(blocking = false) + if (iter == 0) { + val numSamples = sums.count + instr.logNumExamples(numSamples) + } + /* Create new distributions based on the partial assignments (often referred to as the "M" step in literature) @@ -414,6 +423,7 @@ class GaussianMixture @Since("2.0.0") ( iter += 1 } + instances.unpersist(false) val gaussianDists = gaussians.map { case (mean, covVec) => val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) new MultivariateGaussian(mean, cov) @@ -421,10 +431,10 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iter) + instr.logNamedValue("logLikelihood", logLikelihood) + instr.logNamedValue("clusterSizes", summary.clusterSizes) model.setSummary(Some(summary)) - instr.logSuccess(model) - model } @Since("2.0.0") @@ -683,6 +693,7 @@ private class ExpectationAggregator( * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. * @param logLikelihood Total log-likelihood for this model on the given data. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -692,8 +703,9 @@ class GaussianMixtureSummary private[clustering] ( @Since("2.0.0") val probabilityCol: String, featuresCol: String, k: Int, - @Since("2.2.0") val logLikelihood: Double) - extends ClusteringSummary(predictions, predictionCol, featuresCol, k) { + @Since("2.2.0") val logLikelihood: Double, + numIter: Int) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) { /** * Probability of each cluster. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 987a4285ebad4..498310d6644e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,21 +17,24 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -90,7 +93,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -103,8 +106,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable { + private[clustering] val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -123,8 +126,11 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("1.5.0") @@ -140,26 +146,28 @@ class KMeansModel private[ml] ( /** * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. + * + * @deprecated This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator + * instead. You can also get the cost on the training dataset in the summary. */ - // TODO: Replace the temp fix when we have proper evaluators defined for clustering. + @deprecated("This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator " + + "instead. 
You can also get the cost on the training dataset in the summary.", "2.4.0") @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) parentModel.computeCost(data) } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[KMeansModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * */ @Since("1.6.0") - override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) private var trainingSummary: Option[KMeansSummary] = None @@ -185,6 +193,47 @@ class KMeansModel private[ml] ( } } +/** Helper class for storing model data */ +private case class ClusterData(clusterIdx: Int, clusterCenter: Vector) + + +/** A writer for KMeans that handles the "internal" (or default) format */ +private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map { + case (center, idx) => + ClusterData(idx, center) + } + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for KMeans that handles the "pmml" format */ +private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + instance.parentModel.toPMML(sc, path) + } +} + + @Since("1.6.0") object KMeansModel extends MLReadable[KMeansModel] { @@ -194,30 +243,12 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) - /** Helper class for storing model data */ - private case class Data(clusterIdx: Int, clusterCenter: Vector) - /** * We store all cluster centers in a single row and use this class to store model data by * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility. 
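The two writers above register themselves against KMeansModel's stage name, so the persistence format becomes selectable at the call site. A usage sketch, assuming GeneralMLWriter exposes a format(source) selector that is matched against the format() strings registered above (directory paths are placeholders):

import org.apache.spark.ml.clustering.KMeansModel

// Export a fitted model through both registered writers.
def exportKMeans(model: KMeansModel, dir: String): Unit = {
  model.write.format("pmml").save(s"$dir/pmml")           // dispatched to PMMLKMeansModelWriter
  model.write.format("internal").save(s"$dir/internal")   // dispatched to InternalKMeansModelWriter
}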
*/ private case class OldData(clusterCenters: Array[OldVector]) - /** [[MLWriter]] instance for [[KMeansModel]] */ - private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: cluster centers - val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => - Data(idx, center) - } - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) - } - } - private class KMeansModelReader extends MLReader[KMeansModel] { /** Checked against metadata when loading model */ @@ -232,14 +263,14 @@ object KMeansModel extends MLReadable[KMeansModel] { val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { - val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -306,20 +337,19 @@ class KMeans @Since("1.5.0") ( def setSeed(value: Long): this.type = set(seed, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): KMeansModel = { + override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) } - val instr = Instrumentation.create(this, instances) - instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, maxIter, seed, tol) val algo = new MLlibKMeans() .setK($(k)) @@ -332,10 +362,15 @@ class KMeans @Since("1.5.0") ( val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), + $(predictionCol), + $(featuresCol), + $(k), + parentModel.numIter, + parentModel.trainingCost) model.setSummary(Some(summary)) - instr.logSuccess(model) + instr.logNamedValue("clusterSizes", summary.clusterSizes) if (handlePersistence) { instances.unpersist() } @@ -363,6 +398,9 @@ object KMeans extends DefaultParamsReadable[KMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. + * @param trainingCost K-means cost (sum of squared distances to the nearest centroid for all + * points in the training dataset). 
This is equivalent to sklearn's inertia. */ @Since("2.0.0") @Experimental @@ -370,4 +408,7 @@ class KMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int, + @Since("2.4.0") val trainingCost: Double) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 4bab670cc159f..50867f776c522 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -32,6 +32,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, @@ -43,7 +44,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType} import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils @@ -345,7 +346,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" must be >= 1. Found value: $getTopicConcentration") } } - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } @@ -366,7 +367,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM private object LDAParams { /** - * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. * * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with @@ -391,7 +392,7 @@ private object LDAParams { s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } case _ => // 2.0+ - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) } } } @@ -461,7 +462,8 @@ abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() + dataset.withColumn($(topicDistributionCol), + t(DatasetUtils.columnToVector(dataset, getFeaturesCol))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. 
Set an output column" + " such as topicDistributionCol to produce results.") @@ -568,10 +570,14 @@ abstract class LDAModel private[ml] ( class LocalLDAModel private[ml] ( uid: String, vocabSize: Int, - @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel, + private[clustering] val oldLocalModel_ : OldLocalLDAModel, sparkSession: SparkSession) extends LDAModel(uid, vocabSize, sparkSession) { + override private[clustering] def oldLocalModel: OldLocalLDAModel = { + oldLocalModel_.setSeed(getSeed) + } + @Since("1.6.0") override def copy(extra: ParamMap): LocalLDAModel = { val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession) @@ -891,11 +897,12 @@ class LDA @Since("1.6.0") ( override def copy(extra: ParamMap): LDA = defaultCopy(extra) @Since("2.0.0") - override def fit(dataset: Dataset[_]): LDAModel = { + override def fit(dataset: Dataset[_]): LDAModel = instrumented { instr => transformSchema(dataset.schema, logging = true) - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, topicDistributionCol, k, maxIter, subsamplingRate, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, topicDistributionCol, k, maxIter, subsamplingRate, checkpointInterval, keepLastCheckpoint, optimizeDocConcentration, topicConcentration, learningDecay, optimizer, learningOffset, seed) @@ -918,9 +925,7 @@ class LDA @Since("1.6.0") ( } instr.logNumFeatures(newModel.vocabSize) - val model = copyValues(newModel).setParent(this) - instr.logSuccess(model) - model + copyValues(newModel).setParent(this) } @Since("1.6.0") @@ -938,7 +943,7 @@ object LDA extends MLReadable[LDA] { featuresCol: String): RDD[(Long, OldVector)] = { dataset .withColumn("docId", monotonically_increasing_id()) - .select("docId", featuresCol) + .select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol)) .rdd .map { case Row(docId: Long, features: Vector) => (docId, OldVectors.fromML(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 2c30a1d9aa947..1b9a3499947d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -18,21 +18,20 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.Transformer import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ /** * Common params for PowerIterationClustering */ private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter - with HasPredictionCol { + with HasWeightCol { /** * The number of clusters to create (k). Must be > 1. Default: 2. @@ -66,62 +65,33 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has def getInitMode: String = $(initMode) /** - * Param for the name of the input column for vertex IDs. - * Default: "id" + * Param for the name of the input column for source vertex IDs. 
+ * Default: "src" * @group param */ @Since("2.4.0") - val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.", (value: String) => value.nonEmpty) - setDefault(idCol, "id") - - /** @group getParam */ - @Since("2.4.0") - def getIdCol: String = getOrDefault(idCol) - - /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "neighbors" - * @group param - */ - @Since("2.4.0") - val neighborsCol = new Param[String](this, "neighborsCol", - "Name of the input column for neighbors in the adjacency list representation.", - (value: String) => value.nonEmpty) - - setDefault(neighborsCol, "neighbors") - /** @group getParam */ @Since("2.4.0") - def getNeighborsCol: String = $(neighborsCol) + def getSrcCol: String = getOrDefault(srcCol) /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "similarities" + * Name of the input column for destination vertex IDs. + * Default: "dst" * @group param */ @Since("2.4.0") - val similaritiesCol = new Param[String](this, "similaritiesCol", - "Name of the input column for neighbors in the adjacency list representation.", + val dstCol = new Param[String](this, "dstCol", + "Name of the input column for destination vertex IDs.", (value: String) => value.nonEmpty) - setDefault(similaritiesCol, "similarities") - /** @group getParam */ @Since("2.4.0") - def getSimilaritiesCol: String = $(similaritiesCol) + def getDstCol: String = $(dstCol) - protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) - SchemaUtils.checkColumnTypes(schema, $(neighborsCol), - Seq(ArrayType(IntegerType, containsNull = false), - ArrayType(LongType, containsNull = false))) - SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), - Seq(ArrayType(FloatType, containsNull = false), - ArrayType(DoubleType, containsNull = false))) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - } + setDefault(srcCol -> "src", dstCol -> "dst") } /** @@ -131,21 +101,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has * PIC finds a very low-dimensional embedding of a dataset using truncated power * iteration on a normalized pair-wise similarity matrix of the data. * - * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix - * is a symmetric matrix whose entries are non-negative similarities between items. - * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: - * - `idCol`: vertex ID - * - `neighborsCol`: neighbors of vertex in `idCol` - * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex - * in `idCol` and each neighbor in `neighborsCol` - * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` - * containing the cluster assignment in `[0,k)` for each row (vertex). - * - * Notes: - * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. - * Transform runs the iterative PIC algorithm to cluster the whole input dataset. - * - Input validation: This validates that similarities are non-negative but does NOT validate - * that the input matrix is symmetric. 
+ * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the + * PowerIterationClustering algorithm. * * @see * Spectral clustering (Wikipedia) @@ -154,7 +111,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has @Experimental class PowerIterationClustering private[clustering] ( @Since("2.4.0") override val uid: String) - extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + extends PowerIterationClusteringParams with DefaultParamsWritable { setDefault( k -> 2, @@ -164,10 +121,6 @@ class PowerIterationClustering private[clustering] ( @Since("2.4.0") def this() = this(Identifiable.randomUID("PowerIterationClustering")) - /** @group setParam */ - @Since("2.4.0") - def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** @group setParam */ @Since("2.4.0") def setK(value: Int): this.type = set(k, value) @@ -182,66 +135,56 @@ class PowerIterationClustering private[clustering] ( /** @group setParam */ @Since("2.4.0") - def setIdCol(value: String): this.type = set(idCol, value) + def setSrcCol(value: String): this.type = set(srcCol, value) /** @group setParam */ @Since("2.4.0") - def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + def setDstCol(value: String): this.type = set(dstCol, value) /** @group setParam */ @Since("2.4.0") - def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + def setWeightCol(value: String): this.type = set(weightCol, value) + /** + * Run the PIC algorithm and returns a cluster assignment for each input vertex. + * + * @param dataset A dataset with columns src, dst, weight representing the affinity matrix, + * which is the matrix A in the PIC paper. Suppose the src column value is i, + * the dst column value is j, the weight column value is similarity s,,ij,, + * which must be nonnegative. This is a symmetric matrix and hence + * s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + * either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + * ignored, because we assume s,,ij,, = 0.0. + * + * @return A dataset that contains columns of vertex id and the corresponding cluster for the id. + * The schema of it will be: + * - id: Long + * - cluster: Int + */ @Since("2.4.0") - override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema, logging = true) + def assignClusters(dataset: Dataset[_]): DataFrame = { + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) { + lit(1.0) + } else { + col($(weightCol)).cast(DoubleType) + } - val sparkSession = dataset.sparkSession - val idColValue = $(idCol) - val rdd: RDD[(Long, Long, Double)] = - dataset.select( - col($(idCol)).cast(LongType), - col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), - col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) - ).rdd.flatMap { - case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => - require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + - s"equal to the the length of the neighbor similarity list. 
Row for ID " + - s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + - s"of length ${sims.length}.") - nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { - case (nbr, similarity) => (id, nbr, similarity) - } - } + SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType)) + val rdd: RDD[(Long, Long, Double)] = dataset.select( + col($(srcCol)).cast(LongType), + col($(dstCol)).cast(LongType), + w).rdd.map { + case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight) + } val algorithm = new MLlibPowerIterationClustering() .setK($(k)) .setInitializationMode($(initMode)) .setMaxIterations($(maxIter)) val model = algorithm.run(rdd) - val predictionsRDD: RDD[Row] = model.assignments.map { assignment => - Row(assignment.id, assignment.cluster) - } - - val predictionsSchema = StructType(Seq( - StructField($(idCol), LongType, nullable = false), - StructField($(predictionCol), IntegerType, nullable = false))) - val predictions = { - val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.schema($(idCol)).dataType match { - case _: LongType => - uncastPredictions - case otherType => - uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) - } - } - - dataset.join(predictions, $(idCol)) - } - - @Since("2.4.0") - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + import dataset.sparkSession.implicits._ + model.assignments.toDF } @Since("2.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 4353c46781e9d..5c1d1aebdc315 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -21,11 +21,10 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, - SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -107,15 +106,21 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol)) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) + val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol)) + val df = dataset.select(col($(predictionCol)), + vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata)) + ($(metricName), $(distanceMeasure)) match { case ("silhouette", "squaredEuclidean") => 
SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol)) + df, $(predictionCol), $(featuresCol)) case ("silhouette", "cosine") => - CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) + CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol)) + case (mn, dm) => + throw new IllegalArgumentException(s"No support for metric $mn, distance $dm") } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 41eaaf9679914..0554455a66d7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -82,14 +82,12 @@ class BucketedRandomProjectionLSHModel private[ml]( override def setOutputCol(value: String): this.type = super.set(outputCol, value) @Since("2.1.0") - override protected[ml] val hashFunction: Vector => Array[Vector] = { - key: Vector => { - val hashValues: Array[Double] = randUnitVectors.map({ - randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength)) - }) - // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 - hashValues.map(Vectors.dense(_)) - } + override protected[ml] def hashFunction(elems: Vector): Array[Vector] = { + val hashValues = randUnitVectors.map( + randUnitVector => Math.floor(BLAS.dot(elems, randUnitVector) / $(bucketLength)) + ) + // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 + hashValues.map(Vectors.dense(_)) } @Since("2.1.0") @@ -238,7 +236,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject val model = new BucketedRandomProjectionLSHModel(metadata.uid, randUnitVectors.rowIter.toArray) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f49c410cbcfe2..f99649f7fa164 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -217,8 +217,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -296,28 +294,6 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } - - private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. 
- var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 16abc4949dea3..dbfb199ccd58f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 9e0ed437e7bfc..dc8eb8261dbe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} @@ -317,7 +318,9 @@ class CountVectorizerModel( Vectors.sparse(dictBr.value.size, effectiveCounts) } - dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) + val attrs = vocabulary.map(_ => new NumericAttribute).asInstanceOf[Array[Attribute]] + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + dataset.withColumn($(outputCol), vectorizer(col($(inputCol))), metadata) } @Since("1.5.0") @@ -363,7 +366,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { .head() val vocabulary = data.getAs[Seq[String]](0).toArray val model = new CountVectorizerModel(metadata.uid, vocabulary) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 682787a830113..32d98151bdcff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -69,7 +69,8 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + require(inputType.isInstanceOf[VectorUDT], + s"Input type must be ${(new VectorUDT).catalogString} but got ${inputType.catalogString}.") } override protected def outputDataType: DataType = new VectorUDT diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index d67e4819b161a..dc38ee326e5e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -208,8 +208,9 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme require(dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[BooleanType], - s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " + - s"Column $fieldName was $dataType") + s"FeatureHasher requires columns to be of ${NumericType.simpleString}, " + + s"${BooleanType.catalogString} or ${StringType.catalogString}. " + + s"Column $fieldName was ${dataType.catalogString}") } val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index db432b6fefaff..dbda5b8d8fd4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -104,7 +104,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.catalogString}.") val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 46a0730f5ddb8..58897cca4e5c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] { .select("idf") .head() val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf))) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 730ee9fc08db8..1c074e204ad99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] { val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) val model = new ImputerModel(metadata.uid, surrogateDF) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 4ff1d0ef356f3..611f1b691b782 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -261,7 +261,8 @@ private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { */ def 
foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { case d: Double => - assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + assert(numFeatures.length == 1, + s"${DoubleType.catalogString} columns should only contain one feature.") val numOutputCols = numFeatures.head if (numOutputCols > 1) { assert( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index a70931f783f45..b20852383a6ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -75,7 +75,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] * The hash function of LSH, mapping an input feature vector to multiple hash vectors. * @return The mapping of LSH function. */ - protected[ml] val hashFunction: Vector => Array[Vector] + protected[ml] def hashFunction(elems: Vector): Array[Vector] /** * Calculate the distance between two different keys using the distance metric corresponding @@ -97,7 +97,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val transformUDF = udf(hashFunction, DataTypes.createArrayType(new VectorUDT)) + val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT)) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 85f9732f79f67..90eceb0d61b40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 556848e45532d..21cde66d8db6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -60,18 +60,16 @@ class MinHashLSHModel private[ml]( override def setOutputCol(value: String): this.type = super.set(outputCol, value) @Since("2.1.0") - override protected[ml] val hashFunction: Vector => Array[Vector] = { - elems: Vector => { - require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.") - val elemsList = elems.toSparse.indices.toList - val hashValues = randCoefficients.map { case (a, b) => - elemsList.map { elem: Int => - ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME - }.min.toDouble - } - // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 - hashValues.map(Vectors.dense(_)) + override protected[ml] def hashFunction(elems: Vector): Array[Vector] = { + require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.") + val elemsList = elems.toSparse.indices.toList + val hashValues = randCoefficients.map { case (a, b) => + elemsList.map { elem: Int => + ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME + }.min.toDouble } + // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 + hashValues.map(Vectors.dense(_)) } 
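One behavioral detail in the MinHashLSH rewrite above: (1L + elem) * a promotes the product to Long before the modulo, whereas the old Int arithmetic could silently overflow for large indices and coefficients and produce negative hash values. A tiny standalone sketch of the difference (values chosen only to force 32-bit overflow; 2038074743 stands in for MinHashLSH.HASH_PRIME):

val elem = 1000000                         // a large sparse-vector index
val a = 50000                              // a random coefficient
val prime = 2038074743                     // stand-in for MinHashLSH.HASH_PRIME
val asInt = ((1 + elem) * a) % prime       // Int product wraps around; result is negative here
val asLong = ((1L + elem) * a) % prime     // Long product is exact; result stays non-negative
println(s"Int: $asInt, Long: $asLong")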
@Since("2.1.0") @@ -205,7 +203,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { .map(tuple => (tuple(0), tuple(1))).toArray val model = new MinHashLSHModel(metadata.uid, randCoefficients) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index f648deced54cd..2e0ae4af66f06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { .select("originalMin", "originalMax") .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index c8760f9dc178f..e0772d5af20a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -65,7 +65,8 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) override protected def validateInputType(inputType: DataType): Unit = { require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + s"Input type must be ${ArrayType(StringType).catalogString} but got " + + inputType.catalogString) } override protected def outputDataType: DataType = new ArrayType(StringType, false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 5ab6c2dde667a..27e4869a020b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -85,7 +85,8 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + s"Input column must be of type ${NumericType.simpleString} but got " + + schema(inputColName).dataType.catalogString) require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index bd1e3426c8780..4a44f3186538d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { .head() val categorySizes = data.getAs[Seq[Int]](0).toArray val model = new OneHotEncoderModel(metadata.uid, categorySizes) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 4143d864d7930..8172491a517d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] { new PCAModel(metadata.uid, pc.asML, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 3b4c25478fb1d..56e2c543d100a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -253,35 +253,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - private[QuantileDiscretizer] - class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index e214765e3307f..346e1823f00b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -394,7 +394,7 @@ class RFormulaModel private[feature]( require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], - "Label column already exists and is not of type NumericType.") + s"Label column already exists and is not of type ${NumericType.simpleString}.") } @Since("2.0.0") @@ -446,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -510,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) - DefaultParamsReader.getAndSetParams(pruner, metadata) + metadata.getAndSetParams(pruner) pruner } } @@ -602,7 +602,7 @@ 
private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) - DefaultParamsReader.getAndSetParams(rewriter, metadata) + metadata.getAndSetParams(rewriter) rewriter } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 8f125d8fd51d2..91b0707dec3f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3fcd84c029e61..94640a5cbe310 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.feature +import java.util.Locale + import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -84,7 +86,27 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false) + /** + * Locale of the input for case insensitive matching. Ignored when [[caseSensitive]] + * is true. + * Default: Locale.getDefault.toString + * @group param + */ + @Since("2.4.0") + val locale: Param[String] = new Param[String](this, "locale", + "Locale of the input for case insensitive matching. 
Ignored when caseSensitive is true.", + ParamValidators.inArray[String](Locale.getAvailableLocales.map(_.toString))) + + /** @group setParam */ + @Since("2.4.0") + def setLocale(value: String): this.type = set(locale, value) + + /** @group getParam */ + @Since("2.4.0") + def getLocale: String = $(locale) + + setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), + caseSensitive -> false, locale -> Locale.getDefault.toString) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { @@ -95,8 +117,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String terms.filter(s => !stopWordsSet.contains(s)) } } else { - // TODO: support user locale (SPARK-15064) - val toLower = (s: String) => if (s != null) s.toLowerCase else s + val lc = new Locale($(locale)) + val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s val lowerStopWords = $(stopWords).map(toLower(_)).toSet udf { terms: Seq[String] => terms.filter(s => !lowerStopWords.contains(toLower(s))) @@ -109,8 +131,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + require(inputType.sameType(ArrayType(StringType)), "Input type must be " + + s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 67cdb097217a2..a833d8b270cf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { .head() val labels = data.getAs[Seq[String]](0).toArray val model = new StringIndexerModel(metadata.uid, labels) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index cfaf6c0e610b3..aede1f812a552 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -40,7 +40,8 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType == StringType, s"Input type must be string type but got $inputType.") + require(inputType == StringType, + s"Input type must be ${StringType.catalogString} type but got ${inputType.catalogString}.") } override protected def outputDataType: DataType = new ArrayType(StringType, true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 4061154b39c14..57e23d5072b88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -162,7 +162,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) schema(name).dataType match 
{ case _: NumericType | BooleanType => None case t if t.isInstanceOf[VectorUDT] => None - case other => Some(s"Data type $other of column $name is not supported.") + case other => Some(s"Data type ${other.catalogString} of column $name is not supported.") } } if (incorrectColumns.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e6ec4e2e36ff0..0e7396a621dbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { val numFeatures = data.getAs[Int](0) val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index fe3306e1e50d6..fc9996d69ba72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { } val model = new Word2VecModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 3d041fc80eb7f..85c483c387ad8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -26,6 +26,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.HasPredictionCol import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth} import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset @@ -106,7 +107,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.catalogString}.") SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } @@ -158,9 +159,12 @@ class FPGrowth @Since("2.2.0") ( genericFit(dataset) } - private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = instrumented { instr => val handlePersistence = dataset.storageLevel == StorageLevel.NONE + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, params: _*) val data = dataset.select($(itemsCol)) val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) @@ -335,7 +339,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { val dataPath = new Path(path, "data").toString val frequentItems = 
sparkSession.read.parquet(dataPath) val model = new FPGrowthModel(metadata.uid, frequentItems) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala new file mode 100644 index 0000000000000..bd1c1a8885201 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.fpm + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} + +/** + * :: Experimental :: + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth + * (see here). + * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to + * run the PrefixSpan algorithm. + * + * @see Sequential Pattern Mining + * (Wikipedia) + */ +@Since("2.4.0") +@Experimental +final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params { + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("prefixSpan")) + + /** + * Param for the minimal support level (default: `0.1`). + * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are + * identified as frequent sequential patterns. + * @group param + */ + @Since("2.4.0") + val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) " + + "times will be output.", ParamValidators.gtEq(0.0)) + + /** @group getParam */ + @Since("2.4.0") + def getMinSupport: Double = $(minSupport) + + /** @group setParam */ + @Since("2.4.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** + * Param for the maximal pattern length (default: `10`). 
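+   * Patterns longer than this are not generated, so the value bounds how far prefixes are extended during mining.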
+ * @group param + */ + @Since("2.4.0") + val maxPatternLength = new IntParam(this, "maxPatternLength", + "The maximal length of the sequential pattern.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxPatternLength: Int = $(maxPatternLength) + + /** @group setParam */ + @Since("2.4.0") + def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) + + /** + * Param for the maximum number of items (including delimiters used in the internal storage + * format) allowed in a projected database before local processing (default: `32000000`). + * If a projected database exceeds this size, another iteration of distributed prefix growth + * is run. + * @group param + */ + @Since("2.4.0") + val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the internal storage format) " + + "allowed in a projected database before local processing. If a projected database exceeds " + + "this size, another iteration of distributed prefix growth is run.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize) + + /** @group setParam */ + @Since("2.4.0") + def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) + + /** + * Param for the name of the sequence column in the dataset (default "sequence"); rows with + * nulls in this column are ignored. + * @group param + */ + @Since("2.4.0") + val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " + + "the dataset; rows with nulls in this column are ignored.") + + /** @group getParam */ + @Since("2.4.0") + def getSequenceCol: String = $(sequenceCol) + + /** @group setParam */ + @Since("2.4.0") + def setSequenceCol(value: String): this.type = set(sequenceCol, value) + + setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000, + sequenceCol -> "sequence") + + /** + * :: Experimental :: + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * + * @param dataset A dataset or a dataframe containing a sequence column, which is of + * {{{ArrayType(ArrayType(T))}}} type, where T is the item type for the input dataset. + * @return A `DataFrame` that contains columns of sequence and corresponding frequency.
+ * The schema of it will be: + * - `sequence: ArrayType(ArrayType(T))` (T is the item type) + * - `freq: Long` + */ + @Since("2.4.0") + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + val sequenceColParam = $(sequenceCol) + val inputType = dataset.schema(sequenceColParam).dataType + require(inputType.isInstanceOf[ArrayType] && + inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], + s"The input column must be ArrayType and the array element type must also be ArrayType, " + + s"but got $inputType.") + + val data = dataset.select(sequenceColParam) + val sequences = data.where(col(sequenceColParam).isNotNull).rdd + .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) + + val mllibPrefixSpan = new mllibPrefixSpan() + .setMinSupport($(minSupport)) + .setMaxPatternLength($(maxPatternLength)) + .setMaxLocalProjDBSize($(maxLocalProjDBSize)) + + val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) + val schema = StructType(Seq( + StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) + + freqSequences + } + + @Since("2.4.0") + override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra) + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala index 8c975a2fba8ca..1fae1dc04ad7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala @@ -38,13 +38,17 @@ private object RecursiveFlag { */ def withRecursiveFlag[T](value: Boolean, spark: SparkSession)(f: => T): T = { val flagName = FileInputFormat.INPUT_DIR_RECURSIVE + // scalastyle:off hadoopconfiguration val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration val old = Option(hadoopConf.get(flagName)) hadoopConf.set(flagName, value.toString) try f finally { - old match { - case Some(v) => hadoopConf.set(flagName, v) - case None => hadoopConf.unset(flagName) + // avoid false positive of DLS_DEAD_LOCAL_STORE_IN_RETURN by SpotBugs + if (old.isDefined) { + hadoopConf.set(flagName, old.get) + } else { + hadoopConf.unset(flagName) } } } @@ -96,7 +100,9 @@ private object SamplePathFilter { val sampleImages = sampleRatio < 1 if (sampleImages) { val flagName = FileInputFormat.PATHFILTER_CLASS + // scalastyle:off hadoopconfiguration val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration val old = Option(hadoopConf.getClass(flagName, null)) hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio) hadoopConf.setLong(SamplePathFilter.seedParam, seed) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 6961b45f55e4d..572b8cf0051b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -61,9 
+61,12 @@ private[ml] class IterativelyReweightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, - val tol: Double) extends Logging with Serializable { + val tol: Double) extends Serializable { - def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { + def fit( + instances: RDD[OffsetInstance], + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[IterativelyReweightedLeastSquares])): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -83,7 +86,8 @@ private[ml] class IterativelyReweightedLeastSquares( // Estimate new model model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + standardizeFeatures = false, standardizeLabel = false) + .fit(newInstances, instr = instr) // Check convergence val oldCoefficients = oldModel.coefficients @@ -96,14 +100,14 @@ private[ml] class IterativelyReweightedLeastSquares( if (maxTol < tol) { converged = true - logInfo(s"IRLS converged in $iter iterations.") + instr.logInfo(s"IRLS converged in $iter iterations.") } - logInfo(s"Iteration $iter : relative tolerance = $maxTol") + instr.logInfo(s"Iteration $iter : relative tolerance = $maxTol") iter = iter + 1 if (iter == maxIter) { - logInfo(s"IRLS reached the max number of iterations: $maxIter.") + instr.logInfo(s"IRLS reached the max number of iterations: $maxIter.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index c5c9c8eb2bd29..1b7c15f1f0a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -81,13 +81,11 @@ private[ml] class WeightedLeastSquares( val standardizeLabel: Boolean, val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, val maxIter: Int = 100, - val tol: Double = 1e-6) extends Logging with Serializable { + val tol: Double = 1e-6 + ) extends Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") - if (regParam == 0.0) { - logWarning("regParam is zero, which might cause numerical instability and overfitting.") - } require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") @@ -96,10 +94,17 @@ private[ml] class WeightedLeastSquares( /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. 
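   * When an `instr` handle is passed in, progress and warning messages from the solver are logged
   * through that instrumentation rather than through a local logger.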
*/ - def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + def fit( + instances: RDD[Instance], + instr: OptionalInstrumentation = OptionalInstrumentation.create(classOf[WeightedLeastSquares]) + ): WeightedLeastSquaresModel = { + if (regParam == 0.0) { + instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() - logInfo(s"Number of instances: ${summary.count}.") + instr.logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k val numFeatures = summary.k val triK = summary.triK @@ -114,11 +119,12 @@ private[ml] class WeightedLeastSquares( if (rawBStd == 0) { if (fitIntercept || rawBBar == 0.0) { if (rawBBar == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } val coefficients = new DenseVector(Array.ofDim(numFeatures)) @@ -128,7 +134,7 @@ private[ml] class WeightedLeastSquares( } else { require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + "zero. Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. Consider setting " + + instr.logWarning(s"The standard deviation of the label is zero. Consider setting " + s"fitIntercept=true.") } } @@ -256,7 +262,7 @@ private[ml] class WeightedLeastSquares( // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to // Quasi-Newton solver. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => - logWarning("Cholesky solver failed due to singular covariance matrix. " + + instr.logWarning("Cholesky solver failed due to singular covariance matrix. " + "Retrying with Quasi-Newton solver.") // ab and aa were modified in place, so reconstruct them val _aa = getAtA(aaBarValues, aBarValues) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9a83a5882ce29..e6c347ed17c15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable { } /** Internal param map for user-supplied values. */ - private val paramMap: ParamMap = ParamMap.empty + private[ml] val paramMap: ParamMap = ParamMap.empty /** Internal param map for default values. */ - private val defaultParamMap: ParamMap = ParamMap.empty + private[ml] val defaultParamMap: ParamMap = ParamMap.empty /** Validates that the input param belongs to this instance. 
*/ private def shouldOwn(param: Param[_]): Unit = { @@ -905,6 +905,15 @@ trait Params extends Identifiable with Serializable { } } +private[ml] object Params { + /** + * Sets a default param value for a `Params`. + */ + private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = { + params.defaultParamMap.put(param -> value) + } +} + /** * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index b9c3170cc3c28..7e08675f834da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -95,7 +95,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), isValid = "(value: String) => " + - "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"), + ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " + + "each row is for training or for validation. False indicates training; true indicates " + + "validation.") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 282ea6ebcbf7f..5928a0749f738 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -523,4 +523,21 @@ trait HasDistanceMeasure extends Params { /** @group getParam */ final def getDistanceMeasure: String = $(distanceMeasure) } + +/** + * Trait for shared param validationIndicatorCol. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasValidationIndicatorCol extends Params { + + /** + * Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.. + * @group param + */ + final val validationIndicatorCol: Param[String] = new Param[String](this, "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. 
False indicates training; true indicates validation.") + + /** @group getParam */ + final def getValidationIndicatorCol: String = $(validationIndicatorCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 80d03ab03c87d..48485e02edda8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -59,13 +59,13 @@ private[r] class AFTSurvivalRegressionWrapper private ( private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalRegressionWrapper] { + private val FORMULA_REGEXP = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r + private def formulaRewrite(formula: String): (String, String) = { var rewritedFormula: String = null var censorCol: String = null - - val regex = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r try { - val regex(label, censor, features) = formula + val FORMULA_REGEXP(label, censor, features) = formula // TODO: Support dot operator. if (features.contains(".")) { throw new UnsupportedOperationException( diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 81a8f50761e0e..ffe592789b3cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -39,6 +39,7 @@ import org.apache.spark.ml.linalg.BLAS import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD @@ -529,7 +530,7 @@ object ALSModel extends MLReadable[ALSModel] { val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -654,7 +655,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] } @Since("2.0.0") - override def fit(dataset: Dataset[_]): ALSModel = { + override def fit(dataset: Dataset[_]): ALSModel = instrumented { instr => transformSchema(dataset.schema) import dataset.sparkSession.implicits._ @@ -666,8 +667,9 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } - val instr = Instrumentation.create(this, ratings) - instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, itemCol, ratingCol, predictionCol, maxIter, regParam, nonnegative, checkpointInterval, seed, intermediateStorageLevel, finalStorageLevel) @@ -681,7 +683,6 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val userDF = userFactors.toDF("id", "features") val itemDF = itemFactors.toDF("id", "features") val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) - instr.logSuccess(model) copyValues(model) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 
4b46c3831d75f..8d6e36697d2cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -32,6 +32,7 @@ import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils @@ -210,7 +211,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S } @Since("2.0.0") - override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -229,15 +230,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) val numFeatures = featuresStd.size - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, censorCol, predictionCol, quantilesCol, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol, fitIntercept, maxIter, tol, aggregationDepth) instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length) instr.logNumFeatures(numFeatures) + instr.logNumExamples(featuresSummarizer.count) if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { - logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. 
This behavior is different from R survival::survreg.") } @@ -284,10 +287,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val coefficients = Vectors.dense(rawCoefficients) val intercept = parameters(1) val scale = math.exp(parameters(0)) - val model = copyValues(new AFTSurvivalRegressionModel(uid, coefficients, - intercept, scale).setParent(this)) - instr.logSuccess(model) - model + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale).setParent(this)) } @Since("1.6.0") @@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 5cef5c9f21f1e..018290f81842f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -99,37 +100,36 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("2.0.0") def setVarianceCol(value: String): this.type = set(varianceCol, value) - override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { + override protected def train( + dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(params: _*) + instr.logPipelineStage(this) + instr.logDataset(oldDataset) + instr.logParams(this, params: _*) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeRegressionModel] } /** (private[ml]) Train a decision tree on an RDD */ private[ml] def train( data: RDD[LabeledPoint], oldStrategy: OldStrategy, - featureSubsetStrategy: String): DecisionTreeRegressionModel = { - val instr = Instrumentation.create(this, data) - instr.logParams(params: _*) + featureSubsetStrategy: String): DecisionTreeRegressionModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(data) + instr.logParams(this, params: _*) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeRegressionModel] } /** (private[ml]) Create 
a Strategy instance to use with the old API. */ @@ -282,7 +282,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) val model = new DecisionTreeRegressionModel(metadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 834aaa0e362d1..3305881b0ccc6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -31,10 +31,11 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ /** @@ -145,24 +146,44 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - override protected def train(dataset: Dataset[_]): GBTRegressionModel = { + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + + override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val numFeatures = oldDataset.first().features.size + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + + val (trainDataset, validationDataset) = if (withValidation) { + ( + extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))), + extractLabeledPoints(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (extractLabeledPoints(dataset), null) + } + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) - val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) - instr.logSuccess(m) - m + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), 
$(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.0") @@ -269,6 +290,21 @@ class GBTRegressionModel private[ml]( new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + * @param loss The loss function used to compute error. Supported options: squared, absolute + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, + convertToOldLossType(loss), OldAlgo.Regression) + } + @Since("2.0.0") override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } @@ -311,7 +347,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } @@ -319,7 +355,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 9f1f2405c428e..abb60ea205751 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -373,13 +374,15 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val @Since("2.0.0") def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value) - override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { + override protected def train( + dataset: Dataset[_]): GeneralizedLinearRegressionModel = instrumented { instr => val familyAndLink = FamilyAndLink(this) val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, offsetCol, predictionCol, linkPredictionCol, - family, solver, fitIntercept, link, maxIter, regParam, tol) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, offsetCol, predictionCol, + linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol) instr.logNumFeatures(numFeatures) if (numFeatures > 
WeightedLeastSquares.MAX_NUM_FEATURES) { @@ -404,7 +407,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - val wlsModel = optimizer.fit(instances) + val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) @@ -418,10 +421,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val OffsetInstance(label, weight, offset, features) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). - val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam), + instr = OptionalInstrumentation.create(instr)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) - val irlsModel = optimizer.fit(instances) + val irlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) @@ -430,7 +434,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val model.setSummary(Some(trainingSummary)) } - instr.logSuccess(model) model } @@ -471,6 +474,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] val epsilon: Double = 1E-16 + private[regression] def ylogy(y: Double, mu: Double): Double = { + if (y == 0) 0.0 else y * math.log(y / mu) + } + /** * Wrapper of family and link combination used in the model. */ @@ -488,7 +495,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def initialize( instances: RDD[OffsetInstance], fitIntercept: Boolean, - regParam: Double): WeightedLeastSquaresModel = { + regParam: Double, + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[GeneralizedLinearRegression]) + ): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) val eta = predict(mu) - instance.offset @@ -497,7 +507,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine // TODO: Make standardizeFeatures and standardizeLabel configurable. val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - .fit(newInstances) + .fit(newInstances, instr) initialModel } @@ -505,14 +515,13 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * The reweight function used to update working labels and weights * at each iteration of [[IterativelyReweightedLeastSquares]]. 
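     * For an instance it returns the working label `eta - offset + (y - mu) * g'(mu)` and the
     * working weight `w / (g'(mu)^2 * Var(mu))`, where `eta` is the linear predictor including
     * the offset, `mu` is the corresponding mean, and `g'` is the derivative of the link function.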
*/ - val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = { - (instance: OffsetInstance, model: WeightedLeastSquaresModel) => { - val eta = model.predict(instance.features) + instance.offset - val mu = fitted(eta) - val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) - val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) - (newLabel, newWeight) - } + def reweightFunc( + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + instance.offset + val mu = fitted(eta) + val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) + val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (newLabel, newWeight) } } @@ -725,10 +734,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) - private def ylogy(y: Double, mu: Double): Double = { - if (y == 0) 0.0 else y * math.log(y / mu) - } - override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu)) } @@ -783,7 +788,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu override def deviance(y: Double, mu: Double, weight: Double): Double = { - 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + 2.0 * weight * (ylogy(y, mu) - (y - mu)) } override def aic( @@ -1146,7 +1151,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 8faab52ea474b..8b9233dcdc4d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD @@ -161,15 +162,16 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) @Since("2.0.0") - override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = instrumented { instr => transformSchema(dataset.schema, logging = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
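    // handlePersistence is true only when the caller has not already cached the input; in that
    // case the extracted points are cached here and unpersisted once the MLlib model is fitted.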
val instances = extractWeightedLabeledPoints(dataset) val handlePersistence = dataset.storageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic) instr.logNumFeatures(1) val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) @@ -177,9 +179,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri if (handlePersistence) instances.unpersist() - val model = copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) - instr.logSuccess(model) - model + copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) } @Since("1.5.0") @@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val model = new IsotonicRegressionModel( metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f67d9d831f327..ce6c12cc368dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -37,6 +37,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} @@ -315,7 +316,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setEpsilon(value: Double): this.type = set(epsilon, value) setDefault(epsilon -> 1.35) - override protected def train(dataset: Dataset[_]): LinearRegressionModel = { + override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr => // Extract the number of features before deciding optimization solver. 
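    // With solver "auto", the closed-form WeightedLeastSquares path below is used only for
    // squared-error loss when the feature count is small enough for the normal equations;
    // otherwise training falls through to the iterative quasi-Newton optimizers further down.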
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) @@ -326,9 +327,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Instance(label, weight, features) } - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, solver, tol, elasticNetParam, - fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, epsilon) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, solver, tol, + elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, + epsilon) instr.logNumFeatures(numFeatures) if ($(loss) == SquaredError && (($(solver) == Auto && @@ -339,7 +342,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = $(elasticNetParam), $(standardization), true, solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) - val model = optimizer.fit(instances) + val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) // When it is trained by WeightedLeastSquares, training summary does not // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) @@ -353,9 +356,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) - lrModel.setSummary(Some(trainingSummary)) - instr.logSuccess(lrModel) - return lrModel + return lrModel.setSummary(Some(trainingSummary)) } val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -378,6 +379,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) + + instr.logNumExamples(ySummarizer.count) + instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean) + instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd) + if (rawYStd == 0.0) { if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with @@ -385,11 +391,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of // the fitIntercept. 
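        // In this degenerate case the returned model has zero coefficients and an intercept equal
        // to yMean, and the optimization below is skipped entirely.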
if (yMean == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } if (handlePersistence) instances.unpersist() @@ -409,13 +416,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Array(0D), Array(0D)) - model.setSummary(Some(trainingSummary)) - instr.logSuccess(model) - return model + return model.setSummary(Some(trainingSummary)) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") - logWarning(s"The standard deviation of the label is zero. " + + instr.logWarning(s"The standard deviation of the label is zero. " + "Consider setting fitIntercept=true.") } } @@ -430,7 +435,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LinearRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LinearRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -522,7 +527,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } if (state == null) { val msg = s"${optimizer.getClass.getName} failed." 
- logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -590,8 +595,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String objectiveHistory) model.setSummary(Some(trainingSummary)) - instr.logSuccess(model) - model } @Since("1.4.0") @@ -746,7 +749,7 @@ private class InternalLinearRegressionModelWriter /** A writer for LinearRegression that handles the "pmml" format */ private class PMMLLinearRegressionModelWriter - extends MLWriterFormat with MLFormatRegister { + extends MLWriterFormat with MLFormatRegister { override def format(): String = "pmml" @@ -799,7 +802,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f77398ba2a22..35875724b3cfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -114,15 +115,17 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { + override protected def train( + dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, numTrees, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval) @@ -131,9 +134,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .map(_.asInstanceOf[DecisionTreeRegressionModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestRegressionModel(uid, trees, numFeatures) - instr.logSuccess(m) - m + instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) + new RandomForestRegressionModel(uid, trees, numFeatures) } @Since("1.4.0") @@ -276,14 +278,14 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) 
- DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 4e84ff044f55e..39dcd911a0814 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -154,7 +154,7 @@ private[libsvm] class LibSVMFileFormat (file: PartitionedFile) => { val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) val points = linesReader .map(_.toString.trim) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 056a94b351f79..4cdd17266b771 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -77,7 +77,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} * the heaviest part of the computation. In general, this implementation is bound by either * the cost of statistics computation on workers or by communicating the sufficient statistics. */ -private[spark] object RandomForest extends Logging { +private[spark] object RandomForest extends Logging with Serializable { /** * Train a random forest. @@ -91,7 +91,7 @@ private[spark] object RandomForest extends Logging { numTrees: Int, featureSubsetStrategy: String, seed: Long, - instr: Option[Instrumentation[_]], + instr: Option[Instrumentation], prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { @@ -108,9 +108,11 @@ private[spark] object RandomForest extends Logging { case Some(instrumentation) => instrumentation.logNumFeatures(metadata.numFeatures) instrumentation.logNumClasses(metadata.numClasses) + instrumentation.logNumExamples(metadata.numExamples) case None => logInfo("numFeatures: " + metadata.numFeatures) logInfo("numClasses: " + metadata.numClasses) + logInfo("numExamples: " + metadata.numExamples) } // Find the splits and the corresponding bins (interval between the splits) using a sample @@ -405,7 +407,7 @@ private[spark] object RandomForest extends Logging { metadata.isMulticlassWithCategoricalFeatures) logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) - /** + /* * Performs a sequential aggregation over a partition for a particular tree and node. * * For each feature, the aggregate sufficient statistics are updated for the relevant @@ -436,7 +438,7 @@ private[spark] object RandomForest extends Logging { } } - /** + /* * Performs a sequential aggregation over a partition. * * Each data point contributes to one node. 
For each feature, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 81b6222acc7ce..00157fe63af41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -21,6 +21,7 @@ import java.util.Locale import scala.util.Try +import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -460,18 +461,34 @@ private[ml] trait RandomForestRegressorParams * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { - - /* TODO: Add this doc when we add this param. SPARK-7132 - * Threshold for stopping early when runWithValidation is used. - * If the error rate on the validation input changes by less than the validationTol, - * then learning will stop early (before [[numIterations]]). - * This parameter is ignored when run is used. - * (default = 1e-5) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize + with HasValidationIndicatorCol { + + /** + * Threshold for stopping early when fit with validation is used. + * (This parameter is ignored when fit without validation is used.) + * The decision to stop early is decided based on this logic: + * If the current loss on the validation set is greater than 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is less than or equal to 0.01, + * the diff of validation error is compared to absolute tolerance which is + * validationTol * 0.01. * @group param + * @see validationIndicatorCol */ - // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") - // validationTol -> 1e-5 + @Since("2.4.0") + final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", + "Threshold for stopping early when fit with validation is used." + + "If the error rate on the validation input changes by less than the validationTol," + + "then learning will stop early (before `maxIter`)." + + "This parameter is ignored when fit without validation is used.", + ParamValidators.gtEq(0.0) + ) + + /** @group getParam */ + @Since("2.4.0") + final def getValidationTol: Double = $(validationTol) /** * @deprecated This method is deprecated and will be removed in 3.0.0. @@ -497,7 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") @@ -507,7 +524,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) // NOTE: The old API does not support "seed" so we ignore it. 
- new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol) } /** Get old Gradient Boosting Loss type */ @@ -579,7 +596,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { - getLossType match { + convertToOldLossType(getLossType) + } + + private[ml] def convertToOldLossType(loss: String): OldLoss = { + loss match { case "squared" => OldSquaredError case "absolute" => OldAbsoluteError case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index c2826dcc08634..e60a14f976a5c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -118,7 +119,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): CrossValidatorModel = { + override def fit(dataset: Dataset[_]): CrossValidatorModel = instrumented { instr => val schema = dataset.schema transformSchema(schema, logging = true) val sparkSession = dataset.sparkSession @@ -129,8 +130,9 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) // Create execution context based on $(parallelism) val executionContext = getExecutionContext - val instr = Instrumentation.create(this, dataset) - instr.logParams(numFolds, seed, parallelism) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, numFolds, seed, parallelism) logTuningParams(instr) val collectSubModelsParam = $(collectSubModels) @@ -144,7 +146,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache() - logDebug(s"Train split $splitIndex with multiple sets of parameters.") + instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => @@ -155,7 +157,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -169,14 +171,13 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) foldMetrics }.transpose.map(_.sum / 
$(numFolds)) // Calculate average metric over all splits - logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best cross-validation metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics) .setSubModels(subModels).setParent(this)) } @@ -234,8 +235,7 @@ object CrossValidator extends MLReadable[CrossValidator] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(cv, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps"))) cv } } @@ -424,8 +424,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8d1b9a8ddab59..8b251197afbef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils @@ -117,7 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { + override def fit(dataset: Dataset[_]): TrainValidationSplitModel = instrumented { instr => val schema = dataset.schema transformSchema(schema, logging = true) val est = $(estimator) @@ -127,8 +128,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Create execution context based on $(parallelism) val executionContext = getExecutionContext - val instr = Instrumentation.create(this, dataset) - instr.logParams(trainRatio, seed, parallelism) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, trainRatio, seed, parallelism) logTuningParams(instr) val Array(trainingDataset, validationDataset) = @@ -143,7 +145,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } else None // Fit models in a Future for training in parallel - logDebug(s"Train split with multiple sets of parameters.") + instr.logDebug(s"Train split with 
multiple sets of parameters.") val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] @@ -153,7 +155,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -165,14 +167,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.unpersist() validationDataset.unpersist() - logInfo(s"Train validation split metrics: ${metrics.toSeq}") + instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best train validation split metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) .setSubModels(subModels).setParent(this)) } @@ -228,8 +229,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(tvs, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps"))) tvs } } @@ -407,8 +407,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 363304ef10147..135828815504a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -80,7 +80,7 @@ private[ml] trait ValidatorParams extends HasSeed with Params { /** * Instrumentation logging for tuning params including the inner estimator and evaluator info. 
*/ - protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = { + protected def logTuningParams(instrumentation: Instrumentation): Unit = { instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName) instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName) instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala new file mode 100644 index 0000000000000..6af4b3ebc2cc2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, Dataset, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} + + +private[spark] object DatasetUtils { + + /** + * Cast a column in a Dataset to Vector type. + * + * The supported data types of the input column are + * - Vector + * - float/double type Array. + * + * Note: The returned column does not have Metadata. + * + * @param dataset input DataFrame + * @param colName column name. 
+ * @return Vector column + */ + def columnToVector(dataset: Dataset[_], colName: String): Column = { + val columnDataType = dataset.schema(colName).dataType + columnDataType match { + case _: VectorUDT => col(colName) + case fdt: ArrayType => + val transferUDF = fdt.elementType match { + case _: FloatType => udf(f = (vector: Seq[Float]) => { + val inputArray = Array.fill[Double](vector.size)(0.0) + vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble) + Vectors.dense(inputArray) + }) + case _: DoubleType => udf((vector: Seq[Double]) => { + Vectors.dense(vector.toArray) + }) + case other => + throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector") + } + transferUDF(col(colName)) + case other => + throw new IllegalArgumentException(s"$other column cannot be cast to Vector") + } + } + + def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = { + dataset.select(columnToVector(dataset, colName)) + .rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index e694bc27b2f1e..49654918bd8f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,42 +19,61 @@ package org.apache.spark.ml.util import java.util.UUID +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal + import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.util.Utils /** * A small wrapper that defines a training session for an estimator, and some methods to log * useful information during this session. - * - * A new instance is expected to be created within fit(). - * - * @param estimator the estimator that is being fit - * @param dataset the training dataset - * @tparam E the type of the estimator */ -private[spark] class Instrumentation[E <: Estimator[_]] private ( - estimator: E, dataset: RDD[_]) extends Logging { +private[spark] class Instrumentation private () extends Logging { private val id = UUID.randomUUID() - private val prefix = { - val className = estimator.getClass.getSimpleName - s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " + private val shortId = id.toString.take(8) + private[util] val prefix = s"[$shortId] " + + /** + * Log some info about the pipeline stage being fit. + */ + def logPipelineStage(stage: PipelineStage): Unit = { + // estimator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + val className = Utils.getSimpleName(stage.getClass) + logInfo(s"Stage class: $className") + logInfo(s"Stage uid: ${stage.uid}") } - init() + /** + * Log some data about the dataset being fit. + */ + def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd) - private def init(): Unit = { - log(s"training: numPartitions=${dataset.partitions.length}" + + /** + * Log some data about the dataset being fit. 
+ */ + def logDataset(dataset: RDD[_]): Unit = { + logInfo(s"training: numPartitions=${dataset.partitions.length}" + s" storageLevel=${dataset.getStorageLevel}") } + /** + * Logs a debug message with a prefix that uniquely identifies the training session. + */ + override def logDebug(msg: => String): Unit = { + super.logDebug(prefix + msg) + } + /** * Logs a warning message with a prefix that uniquely identifies the training session. */ @@ -76,23 +95,18 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( super.logInfo(prefix + msg) } - /** - * Alias for logInfo, see above. - */ - def log(msg: String): Unit = logInfo(msg) - /** * Logs the value of the given parameters for the estimator being used in this session. */ - def logParams(params: Param[_]*): Unit = { + def logParams(hasParams: Params, params: Param[_]*): Unit = { val pairs: Seq[(String, JValue)] = for { p <- params - value <- estimator.get(p) + value <- hasParams.get(p) } yield { val cast = p.asInstanceOf[Param[Any]] p.name -> parse(cast.jsonEncode(value)) } - log(compact(render(map2jvalue(pairs.toMap)))) + logInfo(compact(render(map2jvalue(pairs.toMap)))) } def logNumFeatures(num: Long): Unit = { @@ -103,22 +117,51 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( logNamedValue(Instrumentation.loggerTags.numClasses, num) } + def logNumExamples(num: Long): Unit = { + logNamedValue(Instrumentation.loggerTags.numExamples, num) + } + /** * Logs the value with customized name field. */ def logNamedValue(name: String, value: String): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Long): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) + } + + def logNamedValue(name: String, value: Double): Unit = { + logInfo(compact(render(name -> value))) + } + + def logNamedValue(name: String, value: Array[String]): Unit = { + logInfo(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Long]): Unit = { + logInfo(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Double]): Unit = { + logInfo(compact(render(name -> compact(render(value.toSeq))))) } + /** * Logs the successful completion of the training session. */ - def logSuccess(model: Model[_]): Unit = { - log(s"training finished") + def logSuccess(): Unit = { + logInfo("training finished") + } + + /** + * Logs an exception raised during a training session. + */ + def logFailure(e: Throwable): Unit = { + val msg = e.getStackTrace.mkString("\n") + super.logError(msg) } } @@ -131,22 +174,71 @@ private[spark] object Instrumentation { val numFeatures = "numFeatures" val numClasses = "numClasses" val numExamples = "numExamples" + val meanOfLabels = "meanOfLabels" + val varianceOfLabels = "varianceOfLabels" + } + + def instrumented[T](body: (Instrumentation => T)): T = { + val instr = new Instrumentation() + Try(body(instr)) match { + case Failure(NonFatal(e)) => + instr.logFailure(e) + throw e + case Success(result) => + instr.logSuccess() + result + } + } +} + +/** + * A small wrapper that contains an optional `Instrumentation` object. + * Provide some log methods, if the containing `Instrumentation` object is defined, + * will log via it, otherwise will log via common logger. 
+ */ +private[spark] class OptionalInstrumentation private( + val instrumentation: Option[Instrumentation], + val className: String) extends Logging { + + protected override def logName: String = className + + override def logInfo(msg: => String) { + instrumentation match { + case Some(instr) => instr.logInfo(msg) + case None => super.logInfo(msg) + } + } + + override def logWarning(msg: => String) { + instrumentation match { + case Some(instr) => instr.logWarning(msg) + case None => super.logWarning(msg) + } } + override def logError(msg: => String) { + instrumentation match { + case Some(instr) => instr.logError(msg) + case None => super.logError(msg) + } + } +} + +private[spark] object OptionalInstrumentation { + /** - * Creates an instrumentation object for a training session. + * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. */ - def create[E <: Estimator[_]]( - estimator: E, dataset: Dataset[_]): Instrumentation[E] = { - create[E](estimator, dataset.rdd) + def create(instr: Instrumentation): OptionalInstrumentation = { + new OptionalInstrumentation(Some(instr), instr.prefix) } /** - * Creates an instrumentation object for a training session. + * Creates an `OptionalInstrumentation` object from a `Class` object. + * The created `OptionalInstrumentation` object will log messages via common logger and use the + * specified class name as logger name. */ - def create[E <: Estimator[_]]( - estimator: E, dataset: RDD[_]): Instrumentation[E] = { - new Instrumentation[E](estimator, dataset) + def create(clazz: Class[_]): OptionalInstrumentation = { + new OptionalInstrumentation(None, clazz.getName.stripSuffix("$")) } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7edcd498678cc..72a60e04360d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -39,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, VersionUtils} /** * Trait for `MLWriter` and `MLReader`. 
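Taken together, the Instrumentation changes above replace the old create-and-logSuccess pattern with the `instrumented` wrapper: the wrapper owns the session lifecycle (logging success, or the failure stack trace, on exit), while the estimator only logs the stage, the dataset, and its params. A rough sketch of the resulting call pattern inside a train method — the estimator, its param names, and the doTrain helper below are illustrative placeholders, not code from this patch:

    import org.apache.spark.ml.util.Instrumentation.instrumented
    import org.apache.spark.ml.util.OptionalInstrumentation

    override protected def train(dataset: Dataset[_]): MyModel = instrumented { instr =>
      // Identify the pipeline stage and the dataset for this training session.
      instr.logPipelineStage(this)
      instr.logDataset(dataset)
      // Params are now logged against an explicit Params owner instead of a captured estimator.
      instr.logParams(this, labelCol, featuresCol, maxIter)

      val model = doTrain(dataset)        // hypothetical training helper
      instr.logNumFeatures(model.numFeatures)
      // Shared .mllib code paths that may also run without a session take an OptionalInstrumentation,
      // e.g. optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) earlier in this diff.
      model                               // logSuccess() / logFailure(e) are handled by `instrumented`
    }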
@@ -421,6 +421,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid + * - defaultParamMap * - paramMap * - (optionally, extra metadata) * @@ -453,15 +454,20 @@ private[ml] object DefaultParamsWriter { paramMap: Option[JValue] = None): String = { val uid = instance.uid val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val params = instance.paramMap.toSeq + val defaultParams = instance.defaultParamMap.toSeq val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) + val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("defaultParamMap" -> jsonDefaultParams) val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject @@ -488,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] { val cls = Utils.classForName(metadata.className) val instance = cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] - DefaultParamsReader.getAndSetParams(instance, metadata) + metadata.getAndSetParams(instance) instance.asInstanceOf[T] } } @@ -499,6 +505,8 @@ private[ml] object DefaultParamsReader { * All info from metadata file. * * @param params paramMap, as a `JValue` + * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4, + * this is `JNothing`. * @param metadata All metadata, including the other fields * @param metadataJson Full metadata file String (for debugging) */ @@ -508,27 +516,90 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + defaultParams: JValue, metadata: JValue, metadataJson: String) { + + private def getValueFromParams(params: JValue): Seq[(String, JValue)] = { + params match { + case JObject(pairs) => pairs + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. * This can be useful for getting a Param value before an instance of `Params` - * is available. + * is available. This will look up `params` first, if not existing then looking up + * `defaultParams`. */ def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats - params match { + + // Looking up for `params` first. + var pairs = getValueFromParams(params) + var foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + if (foundPairs.length == 0) { + // Looking up for `defaultParams` then. + pairs = getValueFromParams(defaultParams) + foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + } + assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + + foundPairs.map(_._2).head + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * @param skipParams The params included in `skipParams` won't be set. 
This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. + */ + def getAndSetParams( + instance: Params, + skipParams: Option[List[String]] = None): Unit = { + setParams(instance, skipParams, isDefault = false) + + // For metadata file prior to Spark 2.4, there is no default section. + val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion) + if (major > 2 || (major == 2 && minor >= 4)) { + setParams(instance, skipParams, isDefault = true) + } + } + + private def setParams( + instance: Params, + skipParams: Option[List[String]], + isDefault: Boolean): Unit = { + implicit val format = DefaultFormats + val paramsToSet = if (isDefault) defaultParams else params + paramsToSet match { case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head + pairs.foreach { case (paramName, jsonValue) => + if (skipParams == None || !skipParams.get.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + if (isDefault) { + Params.setDefault(instance, param, value) + } else { + instance.set(param, value) + } + } + } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") + s"Cannot recognize JSON metadata: ${metadataJson}.") } } } @@ -561,43 +632,14 @@ private[ml] object DefaultParamsReader { val uid = (metadata \ "uid").extract[String] val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] + val defaultParams = metadata \ "defaultParamMap" val params = metadata \ "paramMap" if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) - } - - /** - * Extract Params from metadata, and set them in the instance. - * This works if all Params (except params included by `skipParams` list) implement - * [[org.apache.spark.ml.param.Param.jsonDecode()]]. - * - * @param skipParams The params included in `skipParams` won't be set. This is useful if some - * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] - * and need special handling. 
- * TODO: Move to [[Metadata]] method - */ - def getAndSetParams( - instance: Params, - metadata: Metadata, - skipParams: Option[List[String]] = None): Unit = { - implicit val format = DefaultFormats - metadata.params match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - if (skipParams == None || !skipParams.get.contains(paramName)) { - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) - } - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") - } + Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 334410c9620de..c3894ebdd1785 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.sql.types._ /** @@ -40,7 +41,8 @@ private[spark] object SchemaUtils { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.$message") + s"Column $colName must be of type ${dataType.catalogString} but was actually " + + s"${actualDataType.catalogString}.$message") } /** @@ -57,7 +59,8 @@ private[spark] object SchemaUtils { val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(dataTypes.exists(actualDataType.equals), s"Column $colName must be of type equal to one of the following types: " + - s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + s"${dataTypes.map(_.catalogString).mkString("[", ", ", "]")} but was actually of type " + + s"${actualDataType.catalogString}.$message") } /** @@ -70,8 +73,9 @@ private[spark] object SchemaUtils { msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + - s"NumericType but was actually of type $actualDataType.$message") + require(actualDataType.isInstanceOf[NumericType], + s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + + s"${actualDataType.catalogString}.$message") } /** @@ -101,4 +105,17 @@ private[spark] object SchemaUtils { require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.") StructType(schema.fields :+ col) } + + /** + * Check whether the given column in the schema is one of the supporting vector type: Vector, + * Array[Float]. 
Array[Double] + * @param schema input schema + * @param colName column name + */ + def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + checkColumnTypes(schema, colName, typeCandidates) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b32d3f252ae59..db3f074ecfbac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -572,10 +572,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[java.lang.Iterable[Any]], minSupport: Double, numPartitions: Int): FPGrowthModel[Any] = { - val fpg = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartitions) - + val fpg = new FPGrowth(minSupport, numPartitions) val model = fpg.run(data.rdd.map(_.asScala.toArray)) new FPGrowthModelWrapper(model) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 98af487306dcc..80ab8eb9bc8b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -151,13 +152,10 @@ class BisectingKMeans private ( this } - /** - * Runs the bisecting k-means algorithm. - * @param input RDD of vectors - * @return model for the bisecting kmeans - */ - @Since("1.6.0") - def run(input: RDD[Vector]): BisectingKMeansModel = { + + private[spark] def run( + input: RDD[Vector], + instr: Option[Instrumentation]): BisectingKMeansModel = { if (input.getStorageLevel == StorageLevel.NONE) { logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if" + " its parent RDDs are also not cached.") @@ -171,6 +169,7 @@ class BisectingKMeans private ( val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } var assignments = vectors.map(v => (ROOT_INDEX, v)) var activeClusters = summarize(d, assignments, dMeasure) + instr.foreach(_.logNumExamples(activeClusters.values.map(_.size).sum)) val rootSummary = activeClusters(ROOT_INDEX) val n = rootSummary.size logInfo(s"Number of points: $n.") @@ -246,6 +245,16 @@ class BisectingKMeans private ( new BisectingKMeansModel(root, this.distanceMeasure) } + /** + * Runs the bisecting k-means algorithm. + * @param input RDD of vectors + * @return model for the bisecting kmeans + */ + @Since("1.6.0") + def run(input: RDD[Vector]): BisectingKMeansModel = { + run(input, None) + } + /** * Java-friendly version of `run()`. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b5b1be3490497..d967c672c581f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -235,7 +235,7 @@ class KMeans private ( private[spark] def run( data: RDD[Vector], - instr: Option[Instrumentation[NewKMeans]]): KMeansModel = { + instr: Option[Instrumentation]): KMeansModel = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" @@ -264,7 +264,7 @@ class KMeans private ( */ private def runAlgorithm( data: RDD[VectorWithNorm], - instr: Option[Instrumentation[NewKMeans]]): KMeansModel = { + instr: Option[Instrumentation]): KMeansModel = { val sc = data.sparkContext @@ -299,7 +299,7 @@ class KMeans private ( val bcCenters = sc.broadcast(centers) // Find the new centers - val newCenters = data.mapPartitions { points => + val collected = data.mapPartitions { points => val thisCenters = bcCenters.value val dims = thisCenters.head.vector.size @@ -317,7 +317,13 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) - }.collectAsMap().mapValues { case (sum, count) => + }.collectAsMap() + + if (iteration == 0) { + instr.foreach(_.logNumExamples(collected.values.map(_._2).sum)) + } + + val newCenters = collected.mapValues { case (sum, count) => distanceMeasureInstance.centroid(sum, count) } @@ -348,7 +354,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector), distanceMeasure) + new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a78c21e838e44..d5c8188144ce2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,8 +36,10 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], - @Since("2.4.0") val distanceMeasure: String) +class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String, + @Since("2.4.0") val trainingCost: Double, + private[spark] val numIter: Int) extends Saveable with Serializable with PMMLExportable { private val distanceMeasureInstance: DistanceMeasure = @@ -46,9 +48,13 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vec private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("2.4.0") + private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) = + this(clusterCenters: Array[Vector], distanceMeasure, 0.0, -1) + @Since("1.1.0") def this(clusterCenters: Array[Vector]) = - this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) + this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN, 0.0, -1) /** * A Java-friendly constructor that takes an Iterable of Vectors. 
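The KMeans changes above thread the cost of the final solution and the iteration count into the returned model, and the cost is now written into the saved metadata (models saved before this change load with the placeholder values shown in the secondary constructors). A short usage sketch against the .mllib API — the SparkContext `sc`, the toy data, and `k` below are assumed:

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))

    val model = KMeans.train(data, k = 2, maxIterations = 10)
    // The final training cost now rides along with the model, so a separate
    // computeCost(data) pass is not needed just to report it.
    println(model.trainingCost)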
@@ -182,7 +188,8 @@ object KMeansModel extends Loader[KMeansModel] { val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) - ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure))) + ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure) + ~ ("trainingCost" -> model.trainingCost))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => Cluster(id, p.vector) @@ -202,7 +209,8 @@ object KMeansModel extends Loader[KMeansModel] { val localCentroids = centroids.rdd.map(Cluster.apply).collect() assert(k == localCentroids.length) val distanceMeasure = (metadata \ "distanceMeasure").extract[String] - new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure) + val trainingCost = (metadata \ "trainingCost").extract[Double] + new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure, trainingCost, -1) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b8a6e94248421..f915062d77389 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, Utils} /** * Latent Dirichlet Allocation (LDA) model. @@ -194,6 +194,8 @@ class LocalLDAModel private[spark] ( override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { + private var seed: Long = Utils.random.nextLong() + @Since("1.3.0") override def k: Int = topics.numCols @@ -216,6 +218,21 @@ class LocalLDAModel private[spark] ( override protected def formatVersion = "1.0" + /** + * Random seed for cluster initialization. + */ + @Since("2.4.0") + def getSeed: Long = seed + + /** + * Set the random seed for cluster initialization. 
+ */ + @Since("2.4.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, @@ -298,6 +315,7 @@ class LocalLDAModel private[spark] ( // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta) + val gammaSeed = this.seed // Sum bound components for each document: // component for prob(tokens) + component for prob(document-topic distribution) @@ -306,7 +324,7 @@ class LocalLDAModel private[spark] ( val localElogbeta = ElogbetaBc.value var docBound = 0.0D val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, exp(localElogbeta), brzAlpha, gammaShape, k) + termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id) val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) // E[log p(doc | theta, beta)] @@ -352,6 +370,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -362,7 +381,8 @@ class LocalLDAModel private[spark] ( expElogbetaBc.value, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed + id) (id, Vectors.dense(normalize(gamma, 1.0).toArray)) } } @@ -376,6 +396,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed (termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -386,7 +407,8 @@ class LocalLDAModel private[spark] ( expElogbeta, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } @@ -403,6 +425,7 @@ class LocalLDAModel private[spark] ( */ @Since("2.0.0") def topicDistribution(document: Vector): Vector = { + val gammaSeed = this.seed val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) if (document.numNonzeros == 0) { Vectors.zeros(this.k) @@ -412,7 +435,8 @@ class LocalLDAModel private[spark] ( expElogbeta, this.docConcentration.asBreeze, gammaShape, - this.k) + this.k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 693a2a31f026b..f8e5f3ed76457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape val optimizeDocConcentration = this.optimizeDocConcentration + val seed = randomGenerator.nextLong() // If and only if optimizeDocConcentration is set true, // we calculate logphat in the same pass as other statistics. // No calculation of loghat happens otherwise. 
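The LocalLDAModel changes above seed the per-document gamma initialization, so repeated inference over the same model and input becomes reproducible once a seed is fixed. A brief sketch — the trained `ldaModel: LocalLDAModel` and the corpus `RDD[(Long, Vector)]` are assumed to exist:

    // Fix the seed used for variational topic inference; getSeed/setSeed are the new accessors.
    val seeded = ldaModel.setSeed(42L)
    assert(seeded.getSeed == 42L)
    // Subsequent calls now produce the same per-document topic distributions.
    val topicDist = seeded.topicDistributions(corpus)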
@@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { None } - val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => - val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) - - val stat = BDM.zeros[Double](k, vocabSize) - val logphatPartOption = logphatPartOptionBase() - var nonEmptyDocCount: Long = 0L - nonEmptyDocs.foreach { case (_, termCounts: Vector) => - nonEmptyDocCount += 1 - val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids) + sstats - logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) - } - Iterator((stat, logphatPartOption, nonEmptyDocCount)) + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex { + (index, docs) => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) + + val stat = BDM.zeros[Double](k, vocabSize) + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L + nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 + val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index) + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) + } + Iterator((stat, logphatPartOption, nonEmptyDocCount)) } val elementWiseSum = ( @@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { } override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta) + .setSeed(randomGenerator.nextLong()) } } @@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer { expElogbeta: BDM[Double], alpha: breeze.linalg.Vector[Double], gammaShape: Double, - k: Int): (BDV[Double], BDM[Double], List[Int]) = { + k: Int, + seed: Long): (BDV[Double], BDM[Double], List[Int]) = { val (ids: List[Int], cts: Array[Double]) = termCounts match { case v: DenseVector => ((0 until v.size).toList, v.values) case v: SparseVector => (v.indices.toList, v.values) } // Initialize the variational distribution q(theta|gamma) for the mini-batch + val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed)) val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 3ca75e8cdb97a..ed8543da4d4ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ - * n_t+t &= n_t 
* a + m_t + * n_t+1 &= n_t * a + m_t * \end{align} * $$ * @@ -227,7 +227,7 @@ class StreamingKMeans @Since("1.2.0") ( require(centers.size == k, s"Number of initial centers must be ${k} but got ${centers.size}") require(weights.forall(_ >= 0), - s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]") + s"Weight for each initial center must be nonnegative but got [${weights.mkString(" ")}]") model = new StreamingKMeansModel(centers, weights) this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index f6b1143272d16..4f2b7e6f0764e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -162,7 +162,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { * */ @Since("1.3.0") -class FPGrowth private ( +class FPGrowth private[spark] ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 3f8d65a378e2c..7aed2f3bd8a61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears - * less than maxPatternLength will be output + * @param maxPatternLength the maximal length of the sequential pattern * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local * processing. If a projected database exceeds this size, another diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ac709ad72f0c0..7b49d4d0812f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -78,8 +78,13 @@ class MatrixFactorizationModel @Since("0.8.0") ( /** Predict the rating of one user for one product. */ @Since("0.8.0") def predict(user: Int, product: Int): Double = { - val userVector = userFeatures.lookup(user).head - val productVector = productFeatures.lookup(product).head + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + + val userVector = userFeatureSeq.head + val productVector = productFeatureSeq.head blas.ddot(rank, userVector, 1, productVector, 1) } @@ -164,9 +169,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the product is. 
*/ @Since("1.1.0") - def recommendProducts(user: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) + def recommendProducts(user: Int, num: Int): Array[Rating] = { + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + MatrixFactorizationModel.recommend(userFeatureSeq.head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) + } /** * Recommends users to a product. That is, this returns users who are most likely to be @@ -181,9 +189,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the user is. */ @Since("1.1.0") - def recommendUsers(product: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) + def recommendUsers(product: Int, num: Int): Array[Rating] = { + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + MatrixFactorizationModel.recommend(productFeatureSeq.head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + } protected override val formatVersion: String = "1.0" diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index f0ee5496f9d1d..e6d2a8e2b900e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.RegressionLeafNode -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -365,6 +366,78 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } + test("model evaluateEachIteration") { + val gbt = new GBTClassifier() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType("logistic") + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTClassificationModel("gbt-cls-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses) + val model2 = new GBTClassificationModel("gbt-cls-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses) + + val evalArr = model3.evaluateEachIteration(validationData.toDF) + val remappedValidationData = validationData.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData, + model1.trees, model1.treeWeights, model1.getOldLossType) + val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData, + model2.trees, model2.treeWeights, model2.getOldLossType) + val lossErr3 
= GradientBoostedTrees.computeError(remappedValidationData, + model3.trees, model3.treeWeights, model3.getOldLossType) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTClassifier.supportedLossTypes) { + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val (errorWithoutValidation, errorWithValidation) = { + val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType), + GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees, + modelWithValidation.treeWeights, modelWithValidation.getOldLossType)) + } + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Classification) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 36b7e51f93d01..75c2aeb146786 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(model.getFamily === family) } } + + test("toString") { + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0) + val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3" + assert(model.toString === expected) + } } object LogisticRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 02880f96ae6d9..1b7780e171e77 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.ml.clustering -import 
org.apache.spark.{SparkException, SparkFunSuite} +import scala.language.existentials + +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset -class BisectingKMeansSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -65,10 +69,13 @@ class BisectingKMeansSuite // Verify fit does not fail on very sparse data val model = bkm.fit(sparseDataset) - val result = model.transform(sparseDataset) - val numClusters = result.select("prediction").distinct().collect().length - // Verify we hit the edge case - assert(numClusters < k && numClusters > 1) + + testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") { + rows => + val numClusters = rows.distinct.length + // Verify we hit the edge case + assert(numClusters < k && numClusters > 1) + } } test("setter/getter") { @@ -101,19 +108,16 @@ class BisectingKMeansSuite val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) assert(model.clusterCenters.length === k) - - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } + // Check validity of model summary val numRows = dataset.count() assert(model.hasSummary) @@ -129,6 +133,7 @@ class BisectingKMeansSuite assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 20) model.setSummary(None) assert(!model.hasSummary) @@ -182,6 +187,22 @@ class BisectingKMeansSuite model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) } + + test("BisectingKMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new BisectingKMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } } object BisectingKMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 08b800b7e4183..13bed9dbe3e89 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.stat.distribution.MultivariateGaussian -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { - import testImplicits._ import GaussianMixtureSuite._ + import testImplicits._ final val k = 5 private val seed = 538009335 @@ -118,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.weights.length === k) assert(model.gaussians.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName, probabilityColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - // Check prediction matches the highest probability, and probabilities sum to one. - transformed.select(predictionColName, probabilityColName).collect().foreach { - case Row(pred: Int, prob: Vector) => + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName, probabilityColName) { + case Row(_, pred: Int, prob: Vector) => val probArray = prob.toArray val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2 assert(pred === predFromProb) @@ -150,6 +145,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 2) model.setSummary(None) assert(!model.hasSummary) @@ -256,6 +252,22 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues) assert(symmetricMatrix === expectedMatrix) } + + test("GaussianMixture with Array input") { + def trainAndComputlogLikelihood(dataset: Dataset[_]): Double = { + val model = new GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.summary.logLikelihood + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueLikelihood = trainAndComputlogLikelihood(newDataset) + val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD) + val floatLikelihood = trainAndComputlogLikelihood(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueLikelihood ~== doubleLikelihood absTol 1e-6) + assert(trueLikelihood ~== floatLikelihood absTol 1e-6) + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 32830b39407ad..ccbceab53bb66 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,19 +17,26 @@ package org.apache.spark.ml.clustering +import scala.language.existentials import scala.util.Random -import org.apache.spark.{SparkException, SparkFunSuite} +import org.dmg.pmml.{ClusteringModel, PMML} + +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -103,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val model = kmeans.fit(dataset) assert(model.clusterCenters.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) @@ -126,10 +131,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(summary.predictions.columns.contains(c)) } assert(summary.cluster.columns === Array(predictionColName)) + assert(summary.trainingCost < 0.1) + assert(model.computeCost(dataset) == summary.trainingCost) val clusterSizes = summary.clusterSizes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 1) model.setSummary(None) assert(!model.hasSummary) @@ -143,9 +151,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) - Seq(featuresColName, predictionColName).foreach { column => - assert(transformed.columns.contains(column)) - } + assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName)) assert(model.getFeaturesCol == featuresColName) assert(model.getPredictionCol == predictionColName) } @@ -194,6 +200,23 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR 
assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } + test("KMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } + + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) @@ -202,6 +225,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, KMeansSuite.allParamSettings, checkModelData) } + + test("pmml export") { + val clusterCenters = Array( + MLlibVectors.dense(1.0, 2.0, 6.0), + MLlibVectors.dense(1.0, 3.0, 0.0), + MLlibVectors.dense(1.0, 4.0, 6.0)) + val oldKmeansModel = new MLlibKMeansModel(clusterCenters) + val kmeansModel = new KMeansModel("", oldKmeansModel) + def checkModel(pmml: PMML): Unit = { + // Check the header description is what we expect + assert(pmml.getHeader.getDescription === "k-means clustering") + // Check that the number of fields matches the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verifies that there is a model attached to the pmml object and that the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters as the Spark + // model.
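      // Editor's aside (illustrative only, not part of the patch): per the assertions in this
      // test, the exported document handed to checkModel corresponds roughly to
      //   <PMML>
      //     <Header description="k-means clustering"/>
      //     <DataDictionary numberOfFields="3"/>    (one DataField per vector dimension)
      //     <ClusteringModel numberOfClusters="3"/> (one Cluster per center)
      //   </PMML>
      // so the checks simply walk the Header, the DataDictionary, and the first (and only) model.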
+ val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + testPMMLWrite(sc, kmeansModel, checkModel) + } } object KMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index e73bbc18d76bd..bbd5408c9fce3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.hadoop.fs.Path -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ - object LDASuite { def generateLDAData( spark: SparkSession, @@ -35,9 +34,8 @@ object LDASuite { vocabSize: Int): DataFrame = { val avgWC = 1 // average instances of each word in a doc val sc = spark.sparkContext - val rng = new java.util.Random() - rng.setSeed(1) val rdd = sc.parallelize(1 to rows).map { i => + val rng = new java.util.Random(i) Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) }.map(v => new TestRow(v)) spark.createDataFrame(rdd) @@ -60,7 +58,7 @@ object LDASuite { } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LDASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -185,16 +183,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.topicsMatrix.numCols === k) assert(!model.isDistributed) - // transform() - val transformed = model.transform(dataset) - val expectedColumns = Array("features", lda.getTopicDistributionCol) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - transformed.select(lda.getTopicDistributionCol).collect().foreach { r => - val topicDistribution = r.getAs[Vector](0) - assert(topicDistribution.size === k) - assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", lda.getTopicDistributionCol) { + case Row(_, topicDistribution: Vector) => + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) } // logLikelihood, logPerplexity @@ -252,6 +245,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, LDASuite.allParamSettings, checkModelData) + + // Make sure the result is deterministic after saving and loading the model + val model = lda.fit(dataset) + val model2 = testDefaultReadWrite(model) + assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6) + assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6) } test("read/write DistributedLDAModel") { @@ -286,7 +285,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead // There should be 1 checkpoint remaining. 
assert(model.getCheckpointFiles.length === 1) val checkpointFile = new Path(model.getCheckpointFiles.head) - val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = checkpointFile.getFileSystem(spark.sessionState.newHadoopConf()) assert(fs.exists(checkpointFile)) model.deleteCheckpointFiles() assert(model.getCheckpointFiles.isEmpty) @@ -323,4 +322,21 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getOptimizer === optimizer) } } + + test("LDA with Array input") { + def trainAndLogLikelihoodAndPerplexity(dataset: Dataset[_]): (Double, Double) = { + val model = new LDA().setK(k).setOptimizer("online").setMaxIter(1).setSeed(1).fit(dataset) + (model.logLikelihood(dataset), model.logPerplexity(dataset)) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val (ll, lp) = trainAndLogLikelihoodAndPerplexity(newDataset) + val (llD, lpD) = trainAndLogLikelihoodAndPerplexity(newDatasetD) + val (llF, lpF) = trainAndLogLikelihoodAndPerplexity(newDatasetF) + // TODO: need to compare the results once we fix the seed issue for LDA (SPARK-22210) + assert(llD <= 0.0 && llD != Double.NegativeInfinity) + assert(llF <= 0.0 && llF != Double.NegativeInfinity) + assert(lpD >= 0.0 && lpD != Double.NegativeInfinity) + assert(lpF >= 0.0 && lpF != Double.NegativeInfinity) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 65328df17baff..55b460f1a4524 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -22,14 +22,16 @@ import scala.collection.mutable import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Dataset[_] = _ final val r1 = 1.0 final val n1 = 10 @@ -48,10 +50,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite assert(pic.getK === 2) assert(pic.getMaxIter === 20) assert(pic.getInitMode === "random") - assert(pic.getPredictionCol === "prediction") - assert(pic.getIdCol === "id") - assert(pic.getNeighborsCol === "neighbors") - assert(pic.getSimilaritiesCol === "similarities") + assert(pic.getSrcCol === "src") + assert(pic.getDstCol === "dst") + assert(!pic.isDefined(pic.weightCol)) } test("parameter validation") { @@ -62,125 +63,110 @@ class PowerIterationClusteringSuite extends SparkFunSuite new PowerIterationClustering().setInitMode("no_such_a_mode") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setIdCol("") - } - intercept[IllegalArgumentException] { - new PowerIterationClustering().setNeighborsCol("") + new PowerIterationClustering().setSrcCol("") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setSimilaritiesCol("") + new PowerIterationClustering().setDstCol("") } } test("power iteration 
clustering") { val n = n1 + n2 - val model = new PowerIterationClustering() + val assignments = new PowerIterationClustering() .setK(2) .setMaxIter(40) - val result = model.transform(data) + .setWeightCol("weight") + .assignClusters(data) + .select("id", "cluster") + .as[(Long, Int)] + .collect() val predictions = Array.fill(2)(mutable.Set.empty[Long]) - result.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions(cluster) += id + assignments.foreach { + case (id, cluster) => predictions(cluster) += id } - assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) - val result2 = new PowerIterationClustering() + val assignments2 = new PowerIterationClustering() .setK(2) .setMaxIter(10) .setInitMode("degree") - .transform(data) + .setWeightCol("weight") + .assignClusters(data) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) - result2.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions2(cluster) += id + assignments2.foreach { + case (id, cluster) => predictions2(cluster) += id } - assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) } test("supported input types") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setK(2) .setMaxIter(1) + .setWeightCol("weight") - def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + def runTest(srcType: DataType, dstType: DataType, weightType: DataType): Unit = { val typedData = data.select( - col("id").cast(idType).alias("id"), - col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), - col("similarities").cast(ArrayType(similarityType, containsNull = false)) - .alias("similarities") + col("src").cast(srcType).alias("src"), + col("dst").cast(dstType).alias("dst"), + col("weight").cast(weightType).alias("weight") ) - model.transform(typedData).collect() - } - - for (idType <- Seq(IntegerType, LongType)) { - runTest(idType, LongType, DoubleType) - } - for (neighborType <- Seq(IntegerType, LongType)) { - runTest(LongType, neighborType, DoubleType) - } - for (similarityType <- Seq(FloatType, DoubleType)) { - runTest(LongType, LongType, similarityType) + pic.assignClusters(typedData).collect() } - } - test("invalid input: wrong types") { - val model = new PowerIterationClustering() - .setK(2) - .setMaxIter(1) - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id").cast(DoubleType).alias("id"), - col("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (srcType <- Seq(IntegerType, LongType)) { + runTest(srcType, LongType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (dstType <- Seq(IntegerType, LongType)) { + runTest(LongType, dstType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors"), - col("neighbors").alias("similarities") - ) - model.transform(typedData) + for (weightType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, weightType) } } test("invalid input: negative 
similarity") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setMaxIter(1) + .setWeightCol("weight") val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(-1.0)), - (1, Array(0), Array(-1.0)) - )).toDF("id", "neighbors", "similarities") + (0, 1, -1.0), + (1, 0, -1.0) + )).toDF("src", "dst", "weight") val msg = intercept[SparkException] { - model.transform(badData) + pic.assignClusters(badData) }.getCause.getMessage assert(msg.contains("Similarity must be nonnegative")) } - test("invalid input: mismatched lengths for neighbor and similarity arrays") { - val model = new PowerIterationClustering() - .setMaxIter(1) - val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(0.5)), - (1, Array(0, 2), Array(0.5)), - (2, Array(1), Array(0.5)) - )).toDF("id", "neighbors", "similarities") - val msg = intercept[SparkException] { - model.transform(badData) - }.getCause.getMessage - assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + - "the neighbor similarity list.")) - assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + test("test default weight") { + val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst) + + val assignments = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithoutWeight) + val localAssignments = assignments + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0)) + + val assignments2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithWeightOne) + val localAssignments2 = assignments2 + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + assert(localAssignments === localAssignments2) } test("read/write") { @@ -188,10 +174,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setK(4) .setMaxIter(100) .setInitMode("degree") - .setIdCol("test_id") - .setNeighborsCol("myNeighborsCol") - .setSimilaritiesCol("mySimilaritiesCol") - .setPredictionCol("test_prediction") + .setSrcCol("src1") + .setDstCol("dst1") + .setWeightCol("weight") testDefaultReadWrite(t) } } @@ -222,17 +207,13 @@ object PowerIterationClusteringSuite { val n = n1 + n2 val points = genCircle(r1, n1) ++ genCircle(r2, n2) - val rows = for (i <- 1 until n) yield { - val neighbors = for (j <- 0 until i) yield { - j.toLong + val rows = (for (i <- 1 until n) yield { + for (j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) } - val similarities = for (j <- 0 until i) yield { - sim(points(i), points(j)) - } - (i.toLong, neighbors.toArray, similarities.toArray) - } + }).flatMap(_.iterator) - spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + spark.createDataFrame(rows).toDF("src", "dst", "weight") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index ede284712b1c0..2b0909acf69c3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -67,8 +67,8 @@ class BinaryClassificationEvaluatorSuite evaluator.evaluate(stringDF) } assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + - "equal to one of the following types: [DoubleType, ") - assert(thrown.getMessage.replace("\n", 
"") contains "but was actually of type StringType.") + "equal to one of the following types: [double, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.") } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 2c175ff68e0b8..e2d77560293fa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -33,10 +33,17 @@ class ClusteringEvaluatorSuite import testImplicits._ @transient var irisDataset: Dataset[_] = _ + @transient var newIrisDataset: Dataset[_] = _ + @transient var newIrisDatasetD: Dataset[_] = _ + @transient var newIrisDatasetF: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt") + val datasets = MLTestingUtils.generateArrayFeatureDataset(irisDataset) + newIrisDataset = datasets._1 + newIrisDatasetD = datasets._2 + newIrisDatasetF = datasets._3 } test("params") { @@ -66,6 +73,9 @@ class ClusteringEvaluatorSuite .setPredictionCol("label") assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.6564679231 relTol 1e-5) } /* @@ -85,6 +95,9 @@ class ClusteringEvaluatorSuite .setDistanceMeasure("cosine") assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.7222369298 relTol 1e-5) } test("number of clusters must be greater than one") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 61217669d9277..bca580d411373 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -289,4 +289,20 @@ class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { val newInstance = testDefaultReadWrite(instance) assert(newInstance.vocabulary === instance.vocabulary) } + + test("SPARK-22974: CountVectorModel should attach proper attribute to output column") { + val df = spark.createDataFrame(Seq( + (0, 1.0, Array("a", "b", "c")), + (1, 2.0, Array("a", "b", "b", "c", "a", "d")) + )).toDF("id", "features1", "words") + + val cvm = new CountVectorizerModel(Array("a", "b", "c")) + .setInputCol("words") + .setOutputCol("features2") + + val df1 = cvm.transform(df) + val interaction = new 
Interaction().setInputCols(Array("features1", "features2")) + .setOutputCol("features") + interaction.transform(df1) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index a250331efeb1d..0de6528c4cf22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -105,7 +105,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { testTransformerByInterceptingException[(Int, Boolean)]( original, model, - "Label column already exists and is not of type NumericType.", + "Label column already exists and is not of type numeric.", "x") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 21259a50916d2..20972d1f403b9 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -65,6 +65,57 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { testStopWordsRemover(remover, dataSet) } + test("StopWordsRemover with localed input (case insensitive)") { + val stopWords = Array("milk", "cookie") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setCaseSensitive(false) + .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's. + val dataSet = Seq( + // scalastyle:off + (Seq("mİlk", "and", "nuts"), Seq("and", "nuts")), + // scalastyle:on + (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with localed input (case sensitive)") { + val stopWords = Array("milk", "cookie") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setCaseSensitive(true) + .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's. 
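    // Editor's sketch (not part of the patch): the "tr" locale matters because the JDK
    // case-folds the dotted/dotless 'I' differently there, which is exactly what these
    // two tests exercise. A self-contained illustration:
    //   import java.util.Locale
    //   val tr = new Locale("tr")
    //   "mİlk".toLowerCase(tr)   // "milk"  -- equals the stop word, so it is removed
    //                            //            in the case-insensitive test above
    //   "cookIe".toLowerCase(tr) // "cookıe" (dotless ı) -- never equals "cookie",
    //                            //            so it is kept in both tests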
+ val dataSet = Seq( + // scalastyle:off + (Seq("mİlk", "and", "nuts"), Seq("mİlk", "and", "nuts")), + // scalastyle:on + (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with invalid locale") { + intercept[IllegalArgumentException] { + val stopWords = Array("test", "a", "an", "the") + new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setLocale("rt") // invalid locale + } + } + test("StopWordsRemover case sensitive") { val remover = new StopWordsRemover() .setInputCol("raw") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 91fb24a268b8c..ed15a1d88a269 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -99,9 +99,9 @@ class VectorAssemblerSuite assembler.transform(df) } assert(thrown.getMessage contains - "Data type StringType of column a is not supported.\n" + - "Data type StringType of column b is not supported.\n" + - "Data type StringType of column c is not supported.") + "Data type string of column a is not supported.\n" + + "Data type string of column b is not supported.\n" + + "Data type string of column c is not supported.") } test("ML attributes") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala new file mode 100644 index 0000000000000..2252151af306b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.fpm + +import org.apache.spark.ml.util.MLTest +import org.apache.spark.sql.DataFrame + +class PrefixSpanSuite extends MLTest { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("PrefixSpan projections with multiple partial starts") { + val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") + val result = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(smallDataset) + .as[(Seq[Seq[Int]], Long)].collect() + val expected = Array( + (Seq(Seq(1)), 1L), + (Seq(Seq(1, 2)), 1L), + (Seq(Seq(1), Seq(1)), 1L), + (Seq(Seq(1), Seq(2)), 1L), + (Seq(Seq(1), Seq(3)), 1L), + (Seq(Seq(1, 3)), 1L), + (Seq(Seq(2)), 1L), + (Seq(Seq(2, 3)), 1L), + (Seq(Seq(2), Seq(1)), 1L), + (Seq(Seq(2), Seq(2)), 1L), + (Seq(Seq(2), Seq(3)), 1L), + (Seq(Seq(3)), 1L)) + compareResults[Int](expected, result) + } + + /* + To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 2 1 2 + 1 2 1 3 + 2 1 1 1 + 2 2 2 3 2 + 2 3 2 1 2 + 3 1 2 1 2 + 3 2 1 5 + 4 1 1 6 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 0.75 + 2 <{2}> 0.75 + 3 <{3}> 0.50 + 4 <{1},{3}> 0.50 + 5 <{1,2}> 0.75 + */ + val smallTestData = Seq( + Seq(Seq(1, 2), Seq(3)), + Seq(Seq(1), Seq(3, 2), Seq(1, 2)), + Seq(Seq(1, 2), Seq(5)), + Seq(Seq(6))) + + val smallTestDataExpectedResult = Array( + (Seq(Seq(1)), 3L), + (Seq(Seq(2)), 3L), + (Seq(Seq(3)), 2L), + (Seq(Seq(1), Seq(3)), 2L), + (Seq(Seq(1, 2)), 3L) + ) + + test("PrefixSpan Integer type, variable-size itemsets") { + val df = smallTestData.toDF("sequence") + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan input row with nulls") { + val df = (smallTestData :+ null).toDF("sequence") + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan String type, variable-size itemsets") { + val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap + val df = smallTestData + .map(seq => seq.map(itemSet => itemSet.map(intToString))) + .toDF("sequence") + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) + .as[(Seq[Seq[String]], Long)].collect() + + val expected = smallTestDataExpectedResult.map { case (seq, freq) => + (seq.map(itemSet => itemSet.map(intToString)), freq) + } + compareResults[String](expected, result) + } + + private def compareResults[Item]( + expectedValue: Array[(Seq[Seq[Item]], Long)], + actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = { + val expectedSet = expectedValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + val actualSet = actualValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + assert(expectedSet === 
actualSet) + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e3dfe2faf5698..9a59c41740daf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -594,11 +594,12 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { (check: (ALSModel, ALSModel) => Unit) (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { val dfs = genRatingsDFWithNumericCols(spark, column) - val df = dfs.find { - case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType - } match { - case Some((_, df)) => df + val maybeDf = dfs.find { case (numericTypeWithEncoder, _) => + numericTypeWithEncoder.numericType == baseType } + assert(maybeDf.isDefined) + val df = maybeDf.get._2 + val expected = estimator.fit(df) val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) actuals.foreach { case (_, actual) => check(expected, actual) } @@ -612,7 +613,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { estimator.fit(strDF) } assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) + s"$column must be of type numeric but was actually of type string")) } private class NumericTypeWithEncoder[A](val numericType: NumericType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 4e4ff71c9de90..6cc73e040e82c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -385,7 +385,7 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { aft.fit(dfWithStringCensors) } assert(thrown.getMessage.contains( - "Column censor must be of type NumericType but was actually of type StringType")) + "Column censor must be of type numeric but was actually of type string")) } test("numerical stability of standardization") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index fad11d078250f..b145c7a3dc952 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -20,13 +20,15 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -201,9 +203,81 @@ class GBTRegressorSuite extends MLTest with 
DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } + test("model evaluateEachIteration") { + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType(lossType) + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTRegressionModel("gbt-reg-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures) + val model2 = new GBTRegressionModel("gbt-reg-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures) + + for (evalLossType <- GBTRegressor.supportedLossTypes) { + val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType) + val lossErr1 = GradientBoostedTrees.computeError(validationData, + model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType)) + val lossErr2 = GradientBoostedTrees.computeError(validationData, + model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType)) + val lossErr3 = GradientBoostedTrees.computeError(validationData, + model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType)) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + } + } + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) - ///////////////////////////////////////////////////////////////////////////// + val numIter = 20 + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val errorWithoutValidation = GradientBoostedTrees.computeError(validationData, + modelWithoutValidation.trees, modelWithoutValidation.treeWeights, + modelWithoutValidation.getOldLossType) + val errorWithValidation = GradientBoostedTrees.computeError(validationData, + modelWithValidation.trees, modelWithValidation.treeWeights, + modelWithValidation.getOldLossType) + + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Regression) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index d5bcbb221783e..997c50157dcda 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -493,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest } [1] -0.0457441 -0.6833928 [1] 1.8121235 -0.1747493 -0.5815417 + + R code for deivance calculation: + data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1)) + summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance + [1] 3.70055 + summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance + [1] 3.809296 */ val expected = Seq( Vectors.dense(0.0, -0.0457441, -0.6833928), Vectors.dense(1.8121235, -0.1747493, -0.5815417)) + val residualDeviancesR = Array(3.809296, 3.70055) + import GeneralizedLinearRegression._ var idx = 0 @@ -510,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept (with zero values).") + assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3) idx += 1 } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 15dade2627090..e6ee7220d2279 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class CrossValidatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -66,6 +66,13 @@ class CrossValidatorSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(cvModel.avgMetrics.length === lrParamMaps.length) + + val result = cvModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("cross validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9024342d9c831..cd76acf9c67bc 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -64,6 +64,13 @@ class TrainValidationSplitSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(tvsModel.validationMetrics.length === lrParamMaps.length) + + val result = tvsModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("train validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4da95e74434ee..4d9e664850c12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.util import java.io.{File, IOException} +import org.json4s.JNothing import org.scalatest.Suite -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val shouldNotSetIfSetintParamWithDefault: IntParam = + new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") @@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable { set(doubleArrayParam -> Array(8.0, 9.0)) set(stringArrayParam -> Array("10", "11")) + def checkExclusiveParams(): Unit = { + if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) { + throw new SparkException("intParamWithDefault and 
shouldNotSetIfSetintParamWithDefault " + + "shouldn't be set at the same time") + } + } + override def copy(extra: ParamMap): Params = defaultCopy(extra) override def write: MLWriter = new DefaultParamsWriter(this) @@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext val myParams = new MyParams("my_params") testDefaultReadWrite(myParams) } + + test("default param shouldn't become user-supplied param after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1) + myParams.checkExclusiveParams() + val loadedMyParams = testDefaultReadWrite(myParams) + loadedMyParams.checkExclusiveParams() + assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) == + myParams.getDefault(myParams.intParamWithDefault)) + + loadedMyParams.set(myParams.intParamWithDefault, 1) + intercept[SparkException] { + loadedMyParams.checkExclusiveParams() + } + } + + test("User-supplied value for default param should be kept after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.intParamWithDefault, 100) + val loadedMyParams = testDefaultReadWrite(myParams) + assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100) + } + + test("Read metadata without default field prior to 2.4") { + // default params are saved in `paramMap` field in metadata file prior to Spark 2.4. + val metadata = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.3.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata = DefaultParamsReader.parseMetadata(metadata) + val myParams = new MyParams("my_params") + assert(!myParams.isSet(myParams.intParamWithDefault)) + parsedMetadata.getAndSetParams(myParams) + + // The behavior prior to Spark 2.4, default params are set in loaded ML instance. 
+ assert(myParams.isSet(myParams.intParamWithDefault)) + } + + test("Should raise error when read metadata without default field after Spark 2.4") { + val myParams = new MyParams("my_params") + + val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.4.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1) + val err1 = intercept[IllegalArgumentException] { + parsedMetadata1.getAndSetParams(myParams) + } + assert(err1.getMessage().contains("Cannot recognize JSON metadata")) + + val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"3.0.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2) + val err2 = intercept[IllegalArgumentException] { + parsedMetadata2.getAndSetParams(myParams) + } + assert(err2.getMessage().contains("Cannot recognize JSON metadata")) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index c328d81b4bc3a..91a8b14625a86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.recommendation.{ALS, ALSModel} @@ -74,7 +74,7 @@ object MLTestingUtils extends SparkFunSuite { estimator.fit(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) estimator match { case weighted: Estimator[M] with HasWeightCol => @@ -86,7 +86,7 @@ object MLTestingUtils extends SparkFunSuite { weighted.fit(dfWithStringWeights) } assert(thrown.getMessage.contains( - "Column weight must be of type NumericType but was actually of type StringType")) + "Column weight must be of type numeric but was actually of type string")) case _ => } } @@ -104,7 +104,7 @@ object MLTestingUtils extends SparkFunSuite { evaluator.evaluate(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) } def genClassifDFWithNumericLabelCol( @@ -247,4 +247,25 @@ object MLTestingUtils extends SparkFunSuite { } models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} } + + /** + * Helper function for testing different input types for "features" column. Given a DataFrame, + * generate three output DataFrames: one having vector "features" column with float precision, + * one having double array "features" column with float precision, and one having float array + * "features" column. 
+ */ + def generateArrayFeatureDataset(dataset: Dataset[_], + featuresColName: String = "features"): (Dataset[_], Dataset[_], Dataset[_]) = { + val toFloatVectorUDF = udf { (features: Vector) => + Vectors.dense(features.toArray.map(_.toFloat.toDouble))} + val toDoubleArrayUDF = udf { (features: Vector) => features.toArray} + val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)} + val newDataset = dataset.withColumn(featuresColName, toFloatVectorUDF(col(featuresColName))) + val newDatasetD = newDataset.withColumn(featuresColName, toDoubleArrayUDF(col(featuresColName))) + val newDatasetF = newDataset.withColumn(featuresColName, toFloatArrayUDF(col(featuresColName))) + assert(newDataset.schema(featuresColName).dataType.equals(new VectorUDT)) + assert(newDatasetD.schema(featuresColName).dataType.equals(new ArrayType(DoubleType, false))) + assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) + (newDataset, newDatasetD, newDatasetF) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c8ed057a516a..5ed9d077afe78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -72,6 +72,27 @@ class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkCon } } + test("invalid user and product") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + + intercept[IllegalArgumentException] { + // invalid user + model.predict(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.predict(0, 5) + } + intercept[IllegalArgumentException] { + // invalid user + model.recommendProducts(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.recommendUsers(5, 2) + } + } + test("batch predict API recommendProductsForUsers") { val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) val topK = 10 diff --git a/pom.xml b/pom.xml index 0a711f287a53f..6988c65348652 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + external/avro @@ -113,7 +114,7 @@ 1.8 ${java.version} ${java.version} - 3.3.9 + 3.5.4 spark 1.7.16 1.2.17 @@ -129,33 +130,32 @@ 1.2.1 10.12.1.1 - 1.8.2 - 1.4.3 + 1.10.0 + 1.5.2 nohive 1.6.0 - 9.3.20.v20170531 + 9.3.24.v20180605 3.1.0 0.8.4 2.4.0 2.0.8 3.1.5 - 1.7.7 + 1.8.2 hadoop2 - 0.9.4 - 1.7.3 + 1.8.10 - 1.11.76 + 1.11.271 - 0.10.2 + 0.12.8 - 4.5.4 - 4.4.8 + 4.5.6 + 4.4.10 3.1 3.4.1 3.2.2 - 2.11.8 + 2.11.12 2.11 1.9.13 2.6.7 @@ -170,7 +170,7 @@ 3.5 3.2.10 - 3.0.8 + 3.0.9 2.22.2 2.9.3 3.5.2 @@ -189,10 +189,11 @@ If you are changing Arrow version specification, please check ./python/pyspark/sql/utils.py, ./python/run-tests.py and ./python/setup.py too. --> - 0.8.0 + 0.10.0 ${java.home} + org.spark_project @@ -313,13 +314,13 @@ chill-java ${chill.version}
      - org.apache.xbean - xbean-asm5-shaded - 4.4 + xbean-asm6-shaded + 4.8 jline jline - 2.12.1 + 2.14.6 org.scalatest @@ -760,6 +760,12 @@ 1.10.19 test + + org.jmock + jmock-junit4 + test + 2.8.4 + org.scalacheck scalacheck_${scala.binary.version} @@ -904,6 +910,10 @@ com.sun.jersey.contribs * + + net.java.dev.jets3t + jets3t + @@ -977,24 +987,15 @@ - + - net.java.dev.jets3t - jets3t - ${jets3t.version} + javax.activation + activation + 1.1.1 ${hadoop.deps.scope} - - - commons-logging - commons-logging - - - - - org.bouncycastle - bcprov-jdk15on - - 1.58 org.apache.hadoop @@ -1736,6 +1737,10 @@ org.apache.hadoop hadoop-common + + org.apache.hadoop + hadoop-hdfs + org.apache.hive hive-storage-api @@ -1778,6 +1783,12 @@ parquet-hadoop ${parquet.version} ${parquet.deps.scope} + + + commons-pool + commons-pool + + org.apache.parquet @@ -2110,7 +2121,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.20.1 + 2.22.0 @@ -2149,6 +2160,7 @@ false ${test.exclude.tags} + ${test.include.tags} @@ -2196,6 +2208,7 @@ __not_used__ ${test.exclude.tags} + ${test.include.tags} @@ -2666,11 +2679,20 @@ hadoop-2.7 - 2.7.3 + 2.7.7 2.7.1 + + hadoop-3.1 + + 3.1.0 + 2.12.0 + 3.4.9 + + + yarn @@ -2690,6 +2712,7 @@ kubernetes resource-managers/kubernetes/core + resource-managers/kubernetes/integration-tests @@ -2733,7 +2756,7 @@ scala-2.12 - 2.12.4 + 2.12.6 2.12 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a87fa68422c34..4f250c9943edb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,25 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12 + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), + + // [SPARK-24296][CORE] Replicate large blocks as a stream. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"), + // [SPARK-23528] Add numIter to ClusteringSummary + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), + // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), + + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"), + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), @@ -62,12 +81,32 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"), + + // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + + // [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter") ) // Exclude rules for 2.3.x diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7469f11df0294..1f45a06084c0d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -27,6 +27,7 @@ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.etsy.sbt.checkstyle.CheckstylePlugin.autoImport._ import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import com.typesafe.tools.mima.plugin.MimaKeys @@ -39,8 +40,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, avro) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "avro" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = @@ -56,11 +57,11 @@ object BuildCommons { val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, streamingFlumeSink, streamingFlume, streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, - dockerIntegrationTests, hadoopCloud) = + dockerIntegrationTests, hadoopCloud, kubernetesIntegrationTests) = Seq("kubernetes", "mesos", "yarn", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", - "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) + "docker-integration-tests", "hadoop-cloud", "kubernetes-integration-tests").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") @@ -211,7 +212,7 @@ object SparkBuild extends PomBuild { .map(file), incOptions := incOptions.value.withNameHashing(true), publishMavenStyle := true, - unidocGenjavadocVersion := "0.10", + unidocGenjavadocVersion := "0.11", // Override SBT's default resolvers: resolvers := Seq( @@ -317,7 +318,7 @@ object SparkBuild extends PomBuild { /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ - ExcludedDependencies.settings)) + ExcludedDependencies.settings ++ Checkstyle.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -325,7 +326,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010, kvstore + unsafe, tags, sqlKafka010, kvstore, avro ).contains(x) } @@ -463,7 +464,8 @@ object DockerIntegrationTests { */ object DependencyOverrides { lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1", + dependencyOverrides += "jline" % "jline" % "2.14.6") } /** @@ -686,9 +688,11 @@ object Unidoc { publish := {}, unidocProjectFilter 
in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) @@ -728,7 +732,8 @@ object Unidoc { scalacOptions in (ScalaUnidoc, unidoc) ++= Seq( "-groups", // Group similar methods together based on the @group annotation. - "-skip-packages", "org.apache.hadoop" + "-skip-packages", "org.apache.hadoop", + "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath ) ++ ( // Add links to sources when generating Scaladoc for a non-snapshot release if (!isSnapshot.value) { @@ -740,6 +745,17 @@ object Unidoc { ) } +object Checkstyle { + lazy val settings = Seq( + checkstyleSeverityLevel := Some(CheckstyleSeverityLevel.Error), + javaSource in (Compile, checkstyle) := baseDirectory.value / "src/main/java", + javaSource in (Test, checkstyle) := baseDirectory.value / "src/test/java", + checkstyleConfigLocation := CheckstyleConfigLocation.File("dev/checkstyle.xml"), + checkstyleOutputFile := baseDirectory.value / "target/checkstyle-output.xml", + checkstyleOutputFile in Test := baseDirectory.value / "target/checkstyle-output.xml" + ) +} + object CopyDependencies { val copyDeps = TaskKey[Unit]("copyDeps", "Copies needed dependencies to the build directory.") diff --git a/project/build.properties b/project/build.properties index b19518fd7aa1c..d03985d980ec8 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.16 +sbt.version=0.13.17 diff --git a/project/plugins.sbt b/project/plugins.sbt index 96bdb9067ae59..ffbd417b0f145 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,11 @@ +addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") + +// sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. +libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2" + +// checkstyle uses guava 23.0. +libraryDependencies += "com.google.guava" % "guava" % "23.0" + // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") diff --git a/python/README.md b/python/README.md index 2e0112da58b94..c020d84b01ffd 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). +At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). 
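[Editor's note] The README change above pins PySpark's hard dependency on Py4J 0.10.7 and mentions the optional extras (numpy, pandas, pyarrow) that some sub-packages need. As a minimal, illustrative sketch (not part of Spark), the snippet below checks a local Python environment against those expectations; the function name and messages are made up, and the required Py4J version is taken only from this patch.

    import importlib

    def check_pyspark_python_deps():
        # Hard dependency: Py4J at the version bundled with this Spark build (assumed 0.10.7 here).
        required = {"py4j": "0.10.7"}
        # Optional packages: only needed for some features (pandas UDFs, Arrow conversion, etc.).
        optional = ["numpy", "pandas", "pyarrow"]
        for name, expected in required.items():
            mod = importlib.import_module(name)
            found = getattr(mod, "__version__", "unknown")
            print("%s: found %s, expected %s" % (name, found, expected))
        for name in optional:
            try:
                importlib.import_module(name)
                print("%s: available" % name)
            except ImportError:
                print("%s: missing (the features that need it will be unavailable)" % name)

    if __name__ == "__main__":
        check_pyspark_python_deps()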
diff --git a/python/docs/Makefile b/python/docs/Makefile index 09898f29950ed..1ed1f33af2326 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -1,18 +1,43 @@ # Makefile for Sphinx documentation # +ifndef SPHINXBUILD +ifndef SPHINXPYTHON +SPHINXBUILD = sphinx-build +endif +endif + +ifdef SPHINXBUILD +# User-friendly check for sphinx-build. +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +else +# Note that there is an issue with Python version and Sphinx in PySpark documentation generation. +# Please remove this check below when this issue is fixed. See SPARK-24530 for more details. +PYTHON_VERSION_CHECK = $(shell $(SPHINXPYTHON) -c 'import sys; print(sys.version_info < (3, 0, 0))') +ifeq ($(PYTHON_VERSION_CHECK), True) +$(error Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.) +endif +# Check if Sphinx is installed. +ifeq ($(shell $(SPHINXPYTHON) -c 'import sphinx' >/dev/null 2>&1; echo $$?), 1) +$(error Python executable '$(SPHINXPYTHON)' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +# Use 'SPHINXPYTHON -msphinx' instead of 'sphinx-build'. See https://github.com/sphinx-doc/sphinx/pull/3523 for more details. +SPHINXBUILD = $(SPHINXPYTHON) -msphinx +endif + # You can set these variables from the command line. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build +# You can set SPHINXBUILD to specify Sphinx build executable or SPHINXPYTHON to specify the Python executable used in Sphinx. +# They follow: +# 1. if SPHINXPYTHON is set, use Python. If SPHINXBUILD is set, use sphinx-build. +# 2. If both are set, SPHINXBUILD has a higher priority over SPHINXPYTHON +# 3. By default, SPHINXBUILD is used as 'sphinx-build'. -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip) - -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) # Internal variables. 
PAPEROPT_a4 = -D latex_paper_size=a4 diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip deleted file mode 100644 index 2f8edcc0c0b88..0000000000000 Binary files a/python/lib/py4j-0.10.6-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip new file mode 100644 index 0000000000000..128e321078793 Binary files /dev/null and b/python/lib/py4j-0.10.7-src.zip differ diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index f730d290273fe..30ad04297c682 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -227,20 +227,49 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): def handle(self): from pyspark.accumulators import _accumulatorRegistry - while not self.server.server_shutdown: - # Poll every 1 second for new data -- don't block in case of shutdown. - r, _, _ = select.select([self.rfile], [], [], 1) - if self.rfile in r: - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) + auth_token = self.server.auth_token + + def poll(func): + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. + r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + if func(): + break + + def accum_updates(): + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + return False + + def authenticate_and_accum_updates(): + received_token = self.rfile.read(len(auth_token)) + if isinstance(received_token, bytes): + received_token = received_token.decode("utf-8") + if (received_token == auth_token): + accum_updates() + # we've authenticated, we can break out of the first loop now + return True + else: + raise Exception( + "The value of the provided token to the AccumulatorServer is not correct.") + + # first we keep polling till we've received the authentication token + poll(authenticate_and_accum_updates) + # now we've authenticated, don't need to check for the token anymore + poll(accum_updates) class AccumulatorServer(SocketServer.TCPServer): + def __init__(self, server_address, RequestHandlerClass, auth_token): + SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass) + self.auth_token = auth_token + """ A simple TCP server that intercepts shutdown() in order to interrupt our continuous polling on the handler. 
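[Editor's note] The accumulators.py hunk above adds a small wire protocol to the accumulator update handler: the client first sends the raw UTF-8 auth token, then repeatedly sends a 4-byte big-endian count followed by length-prefixed pickled (accumulator id, update) pairs, and the server acknowledges each batch with a single byte. The sketch below is an illustrative stand-alone client that mirrors that handler; it is not Spark's real client (the JVM-side PythonAccumulatorV2 plays that role), and the integer/ack framing details are read off the handler code above, so treat them as assumptions.

    import pickle
    import socket
    import struct

    def send_accumulator_update(host, port, auth_token, aid, update):
        """Hypothetical helper: authenticate, send one accumulator update, await the ack byte."""
        sock = socket.create_connection((host, port))
        try:
            out = sock.makefile("rwb", 65536)
            # 1. Authenticate: the server reads exactly len(auth_token) bytes and compares them.
            out.write(auth_token.encode("utf-8"))
            out.flush()
            # 2. Send the number of updates, then each update as a length-prefixed pickle.
            out.write(struct.pack("!i", 1))
            payload = pickle.dumps((aid, update), 2)
            out.write(struct.pack("!i", len(payload)))
            out.write(payload)
            out.flush()
            # 3. The handler answers the batch with a single acknowledgement byte (value 1).
            ack = out.read(1)
            return ack == struct.pack("!b", 1)
        finally:
            sock.close()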
@@ -253,9 +282,9 @@ def shutdown(self): self.server_close() -def _start_update_server(): +def _start_update_server(auth_token): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index ea845b98b3db2..88519d7311fcc 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -272,7 +272,7 @@ def save_memoryview(self, obj): if not PY3: def save_buffer(self, obj): self.save(str(obj)) - dispatch[buffer] = save_buffer + dispatch[buffer] = save_buffer # noqa: F821 'buffer' was removed in Python 3 def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) @@ -801,10 +801,10 @@ def save_ellipsis(self, obj): def save_not_implemented(self, obj): self.save_reduce(_gen_not_implemented, ()) - if PY3: - dispatch[io.TextIOWrapper] = save_file - else: + try: # Python 2 dispatch[file] = save_file + except NameError: # Python 3 + dispatch[io.TextIOWrapper] = save_file dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented @@ -819,6 +819,11 @@ def save_logger(self, obj): dispatch[logging.Logger] = save_logger + def save_root_logger(self, obj): + self.save_reduce(logging.getLogger, (), obj=obj) + + dispatch[logging.RootLogger] = save_root_logger + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 7c664966ed74e..4cabae4b2f50b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -126,7 +126,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment = environment or {} # java gateway must have been launched at this point. if conf is not None and conf._jconf is not None: - # conf has been initialized in JVM properly, so use conf directly. This represent the + # conf has been initialized in JVM properly, so use conf directly. This represents the # scenario that JVM has been launched before SparkConf is created (e.g. 
SparkContext is # created and then stopped, and we create a new SparkConf and new SparkContext again) self._conf = conf @@ -183,9 +183,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server - self._accumulatorServer = accumulators._start_update_server() + auth_token = self._gateway.gateway_parameters.auth_token + self._accumulatorServer = accumulators._start_update_server(auth_token) (host, port) = self._accumulatorServer.server_address - self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port) + self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token) self._jsc.sc().register(self._javaAccumulator) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') @@ -211,9 +212,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: - self._python_includes.append(filename) - sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + try: + filepath = os.path.join(SparkFiles.getRootDirectory(), filename) + if not os.path.exists(filepath): + # In case of YARN with shell mode, 'spark.submit.pyFiles' files are + # not added via SparkContext.addFile. Here we check if the file exists, + # try to copy and then add it to the path. See SPARK-21945. + shutil.copyfile(path, filepath) + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: + self._python_includes.append(filename) + sys.path.insert(1, filepath) + except Exception: + warnings.warn( + "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to " + "Python path:\n %s" % (path, "\n ".join(sys.path)), + RuntimeWarning) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) @@ -481,10 +494,14 @@ def f(split, iterator): c = list(c) # Make it a list so we can compute its length batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - jrdd = self._serialize_to_jvm(c, numSlices, serializer) + + def reader_func(temp_filename): + return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices) + + jrdd = self._serialize_to_jvm(c, serializer, reader_func) return RDD(jrdd, self, serializer) - def _serialize_to_jvm(self, data, parallelism, serializer): + def _serialize_to_jvm(self, data, serializer, reader_func): """ Calling the Java parallelize() method with an ArrayList is too slow, because it sends O(n) Py4J commands. As an alternative, serialized @@ -494,8 +511,7 @@ def _serialize_to_jvm(self, data, parallelism, serializer): try: serializer.dump_stream(data, tempFile) tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - return readRDDFromFile(self._jsc, tempFile.name, parallelism) + return reader_func(tempFile.name) finally: # readRDDFromFile eagerily reads the file so we can delete right after. os.unlink(tempFile.name) @@ -835,6 +851,8 @@ def addFile(self, path, recursive=False): A directory can be given if the recursive option is set to True. Currently directories are only supported for Hadoop-supported filesystems. + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. 
+ >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: @@ -855,6 +873,8 @@ def addPyFile(self, path): SparkContext in the future. The C{path} passed can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, HTTPS or FTP URI. + + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix @@ -917,10 +937,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): >>> def stop_job(): ... sleep(5) ... sc.cancelJobGroup("job_to_cancel") - >>> supress = lock.acquire() - >>> supress = threading.Thread(target=start_job, args=(10,)).start() - >>> supress = threading.Thread(target=stop_job).start() - >>> supress = lock.acquire() + >>> suppress = lock.acquire() + >>> suppress = threading.Thread(target=start_job, args=(10,)).start() + >>> suppress = threading.Thread(target=stop_job).start() + >>> suppress = lock.acquire() >>> print(result) Cancelled @@ -998,8 +1018,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) - return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) + sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) + return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7bed5216eabf3..ebdd665e349c5 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -29,7 +29,7 @@ from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT from pyspark.worker import main as worker_main -from pyspark.serializers import read_int, write_int +from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer def compute_real_exit_code(exit_code): @@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code): return 1 -def worker(sock): +def worker(sock, authenticated): """ Called by a worker process after the fork(). """ @@ -56,6 +56,18 @@ def worker(sock): # otherwise writes also cause a seek that makes us miss data on the read side. 
infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) + + if not authenticated: + client_secret = UTF8Deserializer().loads(infile) + if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret: + write_with_length("ok".encode("utf-8"), outfile) + outfile.flush() + else: + write_with_length("err".encode("utf-8"), outfile) + outfile.flush() + sock.close() + return 1 + exit_code = 0 try: worker_main(infile, outfile) @@ -153,8 +165,11 @@ def handle_sigterm(*args): write_int(os.getpid(), outfile) outfile.flush() outfile.close() + authenticated = False while True: - code = worker(sock) + code = worker(sock, authenticated) + if code == 0: + authenticated = True if not reuse or code: # wait for closing try: diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 9cf0e8c8d2fe9..9c4ed46598632 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -27,7 +27,7 @@ def _find_spark_home(): """Find the SPARK_HOME.""" - # If the enviroment has SPARK_HOME set trust it. + # If the environment has SPARK_HOME set trust it. if "SPARK_HOME" in os.environ: return os.environ["SPARK_HOME"] diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index 6af084adcf373..37a2914ebac05 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -710,7 +710,7 @@ def merge(iterables, key=None, reverse=False): # value seen being in the 100 most extreme values is 100/101. # * If the value is a new extreme value, the cost of inserting it into the # heap is 1 + log(k, 2). -# * The probabilty times the cost gives: +# * The probability times the cost gives: # (k/i) * (1 + log(k, 2)) # * Summing across the remaining n-k elements gives: # sum((k/i) * (1 + log(k, 2)) for i in range(k+1, n+1)) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3e704fe9bf6ec..b06503b53be90 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -21,16 +21,19 @@ import select import signal import shlex +import shutil import socket import platform +import tempfile +import time from subprocess import Popen, PIPE if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from pyspark.find_spark_home import _find_spark_home -from pyspark.serializers import read_int +from pyspark.serializers import read_int, write_with_length, UTF8Deserializer def launch_gateway(conf=None): @@ -41,6 +44,7 @@ def launch_gateway(conf=None): """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) + gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] else: SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the @@ -59,40 +63,40 @@ def launch_gateway(conf=None): ]) command = command + shlex.split(submit_args) - # Start a socket that will be used by PythonGatewayServer to communicate its port to us - callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - callback_socket.bind(('127.0.0.1', 0)) - callback_socket.listen(1) - callback_host, callback_port = callback_socket.getsockname() - env = dict(os.environ) - env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host - env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) - - # Launch the Java gateway. 
- # We open a pipe to stdin so that the Java gateway can die when the pipe is broken - if not on_windows: - # Don't send ctrl-c / SIGINT to the Java gateway: - def preexec_func(): - signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) - else: - # preexec_fn not supported on Windows - proc = Popen(command, stdin=PIPE, env=env) - - gateway_port = None - # We use select() here in order to avoid blocking indefinitely if the subprocess dies - # before connecting - while gateway_port is None and proc.poll() is None: - timeout = 1 # (seconds) - readable, _, _ = select.select([callback_socket], [], [], timeout) - if callback_socket in readable: - gateway_connection = callback_socket.accept()[0] - # Determine which ephemeral port the server started on: - gateway_port = read_int(gateway_connection.makefile(mode="rb")) - gateway_connection.close() - callback_socket.close() - if gateway_port is None: - raise Exception("Java gateway process exited before sending the driver its port number") + # Create a temporary directory where the gateway server should write the connection + # information. + conn_info_dir = tempfile.mkdtemp() + try: + fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir) + os.close(fd) + os.unlink(conn_info_file) + + env = dict(os.environ) + env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken + if not on_windows: + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_func(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) + else: + # preexec_fn not supported on Windows + proc = Popen(command, stdin=PIPE, env=env) + + # Wait for the file to appear, or for the process to exit, whichever happens first. + while not proc.poll() and not os.path.isfile(conn_info_file): + time.sleep(0.1) + + if not os.path.isfile(conn_info_file): + raise Exception("Java gateway process exited before sending its port number") + + with open(conn_info_file, "rb") as info: + gateway_port = read_int(info) + gateway_secret = UTF8Deserializer().loads(info) + finally: + shutil.rmtree(conn_info_dir) # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -111,7 +115,9 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, + auto_convert=True)) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") @@ -126,3 +132,69 @@ def killChild(): java_import(gateway.jvm, "scala.Tuple2") return gateway + + +def _do_server_auth(conn, auth_secret): + """ + Performs the authentication protocol defined by the SocketAuthHelper class on the given + file-like object 'conn'. + """ + write_with_length(auth_secret.encode("utf-8"), conn) + conn.flush() + reply = UTF8Deserializer().loads(conn) + if reply != "ok": + conn.close() + raise Exception("Unexpected reply from iterator server.") + + +def local_connect_and_auth(port, auth_secret): + """ + Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection. + Handles IPV4 & IPV6, does some error handling. 
+ :param port + :param auth_secret + :return: a tuple with (sockfile, sock) + """ + sock = None + errors = [] + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, _, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(15) + sock.connect(sa) + sockfile = sock.makefile("rwb", 65536) + _do_server_auth(sockfile, auth_secret) + return (sockfile, sock) + except socket.error as e: + emsg = _exception_message(e) + errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) + sock.close() + sock = None + else: + raise Exception("could not open socket: %s" % errors) + + +def ensure_callback_server_started(gw): + """ + Start callback server if not already started. The callback server is needed if the Java + driver process needs to callback into the Python driver process to execute Python code. + """ + + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index 129d7d68f7cbb..d99a25390db15 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -21,5 +21,11 @@ """ from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer from pyspark.ml.pipeline import Pipeline, PipelineModel +from pyspark.ml import classification, clustering, evaluation, feature, fpm, \ + image, pipeline, recommendation, regression, stat, tuning, util, linalg, param -__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"] +__all__ = [ + "Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel", + "classification", "clustering", "evaluation", "feature", "fpm", "image", + "recommendation", "regression", "stat", "tuning", "util", "linalg", "param", +] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ec17653a1adf9..d5963f4f7042c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> blorModel.intercept == model2.intercept True + >>> model2 + LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2 .. 
versionadded:: 1.3.0 """ @@ -562,6 +564,9 @@ def evaluate(self, dataset): java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionSummary(JavaWrapper): """ @@ -1131,6 +1136,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestClassificationModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): @@ -1193,6 +1205,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) + >>> gbt.getFeatureSubsetStrategy() + 'all' >>> model = gbt.fit(td) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1222,6 +1236,12 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)], + ... ["indexed", "features"]) + >>> model.evaluateEachIteration(validation) + [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] + >>> model.numClasses + 2 .. versionadded:: 1.4.0 """ @@ -1240,19 +1260,22 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") """ super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0, + featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1261,12 +1284,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ 
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Classification. """ kwargs = self._input_kwargs @@ -1289,8 +1314,15 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + -class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, +class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. @@ -1319,6 +1351,17 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + return self._call_java("evaluateEachIteration", dataset) + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b3d5fb17f6b81..ab449bc3f8f51 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -16,17 +16,19 @@ # import sys +import warnings from pyspark import since, keyword_only from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', 'KMeans', 'KMeansModel', 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', - 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] + 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel', 'PowerIterationClustering'] class ClusteringSummary(JavaWrapper): @@ -86,6 +88,14 @@ def clusterSizes(self): """ return self._call_java("clusterSizes") + @property + @since("2.4.0") + def numIter(self): + """ + Number of iterations. + """ + return self._call_java("numIter") + class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): """ @@ -302,7 +312,15 @@ class KMeansSummary(ClusteringSummary): .. versionadded:: 2.1.0 """ - pass + + @property + @since("2.4.0") + def trainingCost(self): + """ + K-means cost (sum of squared distances to the nearest centroid for all points in the + training dataset). This is equivalent to sklearn's inertia. + """ + return self._call_java("trainingCost") class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): @@ -322,7 +340,13 @@ def computeCost(self, dataset): """ Return the K-means cost (sum of squared distances of points to their nearest center) for this model on the given data. + + ..note:: Deprecated in 2.4.0. It will be removed in 3.0.0. 
Use ClusteringEvaluator instead. + You can also get the cost on the training dataset in the summary. """ + warnings.warn("Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator " + "instead. You can also get the cost on the training dataset in the summary.", + DeprecationWarning) return self._call_java("computeCost", dataset) @property @@ -348,8 +372,8 @@ def summary(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, - JavaMLWritable, JavaMLReadable): +class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, HasMaxIter, + HasTol, HasSeed, JavaMLWritable, JavaMLReadable): """ K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). @@ -378,6 +402,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol 2 >>> summary.clusterSizes [2, 2] + >>> summary.trainingCost + 2.000... >>> kmeans_path = temp_path + "/kmeans" >>> kmeans.save(kmeans_path) >>> kmeans2 = KMeans.load(kmeans_path) @@ -405,9 +431,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) - distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -543,8 +566,8 @@ def summary(self): @inherit_doc -class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, - JavaMLWritable, JavaMLReadable): +class BisectingKMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, + HasMaxIter, HasSeed, JavaMLWritable, JavaMLReadable): """ A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. 
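[Editor's note] The clustering.py hunks above deprecate KMeansModel.computeCost and point users at the new summary.trainingCost property and at ClusteringEvaluator. A minimal PySpark sketch of that migration follows; the SparkSession setup, column name, and toy data points are invented for illustration only.

    from pyspark.ml.clustering import KMeans
    from pyspark.ml.evaluation import ClusteringEvaluator
    from pyspark.ml.linalg import Vectors
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("kmeans-cost-example").getOrCreate()
    data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
            (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
    df = spark.createDataFrame(data, ["features"])

    model = KMeans(k=2, seed=1).fit(df)
    # Training cost from the summary (the property added above) instead of the
    # deprecated model.computeCost(df) on the training data.
    print(model.summary.trainingCost)

    # Preferred evaluation path going forward: silhouette via ClusteringEvaluator.
    predictions = model.transform(df)
    print(ClusteringEvaluator().evaluate(predictions))
    spark.stop()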
@@ -584,6 +607,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> bkm2 = BisectingKMeans.load(bkm_path) >>> bkm2.getK() 2 + >>> bkm2.getDistanceMeasure() + 'euclidean' >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) @@ -606,10 +631,10 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") """ super(BisectingKMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", @@ -621,10 +646,10 @@ def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=2 @keyword_only @since("2.0.0") def setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") Sets params for BisectingKMeans. """ kwargs = self._input_kwargs @@ -658,6 +683,20 @@ def getMinDivisibleClusterSize(self): """ return self.getOrDefault(self.minDivisibleClusterSize) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + def _create_model(self, java_model): return BisectingKMeansModel(java_model) @@ -836,7 +875,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter Terminology: - - "term" = "word": an el + - "term" = "word": an element of the vocabulary - "token": instance of a term appearing in a document - "topic": multinomial distribution over terms representing some concept - "document": one piece of text, corresponding to one row in the input data @@ -938,7 +977,7 @@ def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInte k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ subsamplingRate=0.05, optimizeDocConcentration=True,\ docConcentration=None, topicConcentration=None,\ - topicDistributionCol="topicDistribution", keepLastCheckpoint=True): + topicDistributionCol="topicDistribution", keepLastCheckpoint=True) """ super(LDA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid) @@ -967,7 +1006,7 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ subsamplingRate=0.05, optimizeDocConcentration=True,\ docConcentration=None, topicConcentration=None,\ - topicDistributionCol="topicDistribution", keepLastCheckpoint=True): + topicDistributionCol="topicDistribution", keepLastCheckpoint=True) Sets params for LDA. 
""" @@ -996,7 +1035,7 @@ def getK(self): def setOptimizer(self, value): """ Sets the value of :py:attr:`optimizer`. - Currenlty only support 'em' and 'online'. + Currently only support 'em' and 'online'. >>> algo = LDA().setOptimizer("em") >>> algo.getOptimizer() @@ -1156,10 +1195,189 @@ def getKeepLastCheckpoint(self): return self.getOrDefault(self.keepLastCheckpoint) +@inherit_doc +class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReadable, + JavaMLWritable): + """ + .. note:: Experimental + + Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + Lin and Cohen. From the abstract: + PIC finds a very low-dimensional embedding of a dataset using truncated power + iteration on a normalized pair-wise similarity matrix of the data. + + This class is not yet an Estimator/Transformer, use :py:func:`assignClusters` method + to run the PowerIterationClustering algorithm. + + .. seealso:: `Wikipedia on Spectral clustering \ + `_ + + >>> data = [(1, 0, 0.5), \ + (2, 0, 0.5), (2, 1, 0.7), \ + (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), \ + (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), \ + (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)] + >>> df = spark.createDataFrame(data).toDF("src", "dst", "weight") + >>> pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight") + >>> assignments = pic.assignClusters(df) + >>> assignments.sort(assignments.id).show(truncate=False) + +---+-------+ + |id |cluster| + +---+-------+ + |0 |1 | + |1 |1 | + |2 |1 | + |3 |1 | + |4 |1 | + |5 |0 | + +---+-------+ + ... + >>> pic_path = temp_path + "/pic" + >>> pic.save(pic_path) + >>> pic2 = PowerIterationClustering.load(pic_path) + >>> pic2.getK() + 2 + >>> pic2.getMaxIter() + 40 + + .. versionadded:: 2.4.0 + """ + + k = Param(Params._dummy(), "k", + "The number of clusters to create. Must be > 1.", + typeConverter=TypeConverters.toInt) + initMode = Param(Params._dummy(), "initMode", + "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use " + + "a normalized sum of similarities with other vertices. Supported options: " + + "'random' and 'degree'.", + typeConverter=TypeConverters.toString) + srcCol = Param(Params._dummy(), "srcCol", + "Name of the input column for source vertex IDs.", + typeConverter=TypeConverters.toString) + dstCol = Param(Params._dummy(), "dstCol", + "Name of the input column for destination vertex IDs.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", + weightCol=None): + """ + __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ + weightCol=None) + """ + super(PowerIterationClustering, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid) + self._setDefault(k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", + weightCol=None): + """ + setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ + weightCol=None) + Sets params for PowerIterationClustering. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def setK(self, value): + """ + Sets the value of :py:attr:`k`. 
+ """ + return self._set(k=value) + + @since("2.4.0") + def getK(self): + """ + Gets the value of :py:attr:`k` or its default value. + """ + return self.getOrDefault(self.k) + + @since("2.4.0") + def setInitMode(self, value): + """ + Sets the value of :py:attr:`initMode`. + """ + return self._set(initMode=value) + + @since("2.4.0") + def getInitMode(self): + """ + Gets the value of :py:attr:`initMode` or its default value. + """ + return self.getOrDefault(self.initMode) + + @since("2.4.0") + def setSrcCol(self, value): + """ + Sets the value of :py:attr:`srcCol`. + """ + return self._set(srcCol=value) + + @since("2.4.0") + def getSrcCol(self): + """ + Gets the value of :py:attr:`srcCol` or its default value. + """ + return self.getOrDefault(self.srcCol) + + @since("2.4.0") + def setDstCol(self, value): + """ + Sets the value of :py:attr:`dstCol`. + """ + return self._set(dstCol=value) + + @since("2.4.0") + def getDstCol(self): + """ + Gets the value of :py:attr:`dstCol` or its default value. + """ + return self.getOrDefault(self.dstCol) + + @since("2.4.0") + def assignClusters(self, dataset): + """ + Run the PIC algorithm and returns a cluster assignment for each input vertex. + + :param dataset: + A dataset with columns src, dst, weight representing the affinity matrix, + which is the matrix A in the PIC paper. Suppose the src column value is i, + the dst column value is j, the weight column value is similarity s,,ij,, + which must be nonnegative. This is a symmetric matrix and hence + s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + ignored, because we assume s,,ij,, = 0.0. + + :return: + A dataset that contains columns of vertex id and the corresponding cluster for + the id. The schema of it will be: + - id: Long + - cluster: Int + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.assignClusters(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) + + if __name__ == "__main__": import doctest + import numpy import pyspark.ml.clustering from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cdda30cfab482..760aa82168f5a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1294,14 +1294,14 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, >>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345) >>> model = mh.fit(df) >>> model.transform(df).head() - Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([-1638925... + Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668... >>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),), ... (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),), ... (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)] >>> df2 = spark.createDataFrame(data2, ["id", "features"]) >>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0]) >>> model.approxNearestNeighbors(df2, key, 1).collect() - [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([-163892... 
+ [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668... >>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select( ... col("datasetA.id").alias("idA"), ... col("datasetB.id").alias("idB"), @@ -1309,8 +1309,8 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, +---+---+---------------+ |idA|idB|JaccardDistance| +---+---+---------------+ - | 1| 4| 0.5| | 0| 5| 0.5| + | 1| 4| 0.5| +---+---+---------------+ ... >>> mhPath = temp_path + "/mh" @@ -2582,25 +2582,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl typeConverter=TypeConverters.toListString) caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + "comparison over the stop words", typeConverter=TypeConverters.toBoolean) + locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " + + "is true", typeConverter=TypeConverters.toString) @keyword_only - def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, + locale=None): """ - __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ + locale=None) """ super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), - caseSensitive=False) + caseSensitive=False, locale=self._java_obj.getLocale()) kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, + locale=None): """ - setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ + locale=None) Sets params for this StopWordRemover. """ kwargs = self._input_kwargs @@ -2634,6 +2640,20 @@ def getCaseSensitive(self): """ return self.getOrDefault(self.caseSensitive) + @since("2.4.0") + def setLocale(self, value): + """ + Sets the value of :py:attr:`locale`. + """ + return self._set(locale=value) + + @since("2.4.0") + def getLocale(self): + """ + Gets the value of :py:attr:`locale`. 
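A small sketch of the new locale parameter, mirroring the Turkish-locale case exercised in the test suite; the locale only matters when caseSensitive is False (the default):

from pyspark.ml.feature import StopWordsRemover

remover = StopWordsRemover(inputCol="words", outputCol="filtered",
                           stopWords=["BELKİ"], locale="tr")
# Case folding now follows the Turkish locale, so "belki" is removed as well.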
+ """ + return self.getOrDefault(self.locale) + @staticmethod @since("2.0.0") def loadDefaultStopWords(language): @@ -3823,12 +3843,12 @@ def setParams(self, inputCol=None, size=None, handleInvalid="error"): @since("2.3.0") def getSize(self): """ Gets size param, the size of vectors in `inputCol`.""" - self.getOrDefault(self.size) + return self.getOrDefault(self.size) @since("2.3.0") def setSize(self, value): """ Sets size param, the size of vectors in `inputCol`.""" - self._set(size=value) + return self._set(size=value) if __name__ == "__main__": diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index b8dafd49d354d..f9394421e0cc4 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -16,11 +16,12 @@ # from pyspark import keyword_only, since +from pyspark.sql import DataFrame from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * -__all__ = ["FPGrowth", "FPGrowthModel"] +__all__ = ["FPGrowth", "FPGrowthModel", "PrefixSpan"] class HasMinSupport(Params): @@ -243,3 +244,105 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", def _create_model(self, java_model): return FPGrowthModel(java_model) + + +class PrefixSpan(JavaParams): + """ + .. note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + Efficiently by Prefix-Projected Pattern Growth + (see here). + This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` + method to run the PrefixSpan algorithm. + + @see Sequential Pattern Mining + (Wikipedia) + .. versionadded:: 2.4.0 + + """ + + minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.", + typeConverter=TypeConverters.toFloat) + + maxPatternLength = Param(Params._dummy(), "maxPatternLength", + "The maximal length of the sequential pattern. Must be > 0.", + typeConverter=TypeConverters.toInt) + + maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the " + + "internal storage format) allowed in a projected database before " + + "local processing. If a projected database exceeds this size, " + + "another iteration of distributed prefix growth is run. 
" + + "Must be > 0.", + typeConverter=TypeConverters.toInt) + + sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + super(PrefixSpan, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid) + self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def findFrequentSequentialPatterns(self, dataset): + """ + .. note:: Experimental + + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param dataset: A dataframe containing a sequence column which is + `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. + :return: A `DataFrame` that contains columns of sequence and corresponding frequency. + The schema of it will be: + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` + + >>> from pyspark.ml.fpm import PrefixSpan + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]), + ... Row(sequence=[[1], [3, 2], [1, 2]]), + ... Row(sequence=[[1, 2], [5]]), + ... Row(sequence=[[6]])]).toDF() + >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5) + >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False) + +----------+----+ + |sequence |freq| + +----------+----+ + |[[1]] |3 | + |[[1], [3]]|2 | + |[[1, 2]] |3 | + |[[2]] |3 | + |[[3]] |2 | + +----------+----+ + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 96d702f844839..5f0c57ee3cc67 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -31,6 +31,8 @@ from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string from pyspark.sql import DataFrame, SparkSession +__all__ = ["ImageSchema"] + class _ImageSchema(object): """ diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 6a611a2b5b59d..2548fd0f50b33 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1156,6 +1156,11 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values): def _test(): import doctest + try: + # Numpy 1.14+ changed it's string format. 
+ np.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 6e9e0a34cdfde..e45ba840b412b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -162,7 +162,9 @@ def get$Name(self): "fitting. If set to true, then all sub-models will be available. Warning: For large " + "models, collecting all sub-models can cause OOMs on the Spark driver.", "False", "TypeConverters.toBoolean"), - ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] + ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"), + ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", + "'euclidean'", "TypeConverters.toString")] code = [] for name, doc, defaultValueStr, typeConverter in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 08408ee8fbfcc..618f5bf0a8103 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -790,3 +790,27 @@ def getCacheNodeIds(self): """ return self.getOrDefault(self.cacheNodeIds) + +class HasDistanceMeasure(Params): + """ + Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'. + """ + + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasDistanceMeasure, self).__init__() + self._setDefault(distanceMeasure='euclidean') + + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + def getDistanceMeasure(self): + """ + Gets the value of distanceMeasure or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9a66d87d7f211..513ca5a9df85e 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction True >>> model.numFeatures 1 + >>> model.write().format("pmml").save(model_path + "_2") .. versionadded:: 1.4.0 """ @@ -161,7 +162,7 @@ def getEpsilon(self): return self.getOrDefault(self.epsilon) -class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable): """ Model fitted by :class:`LinearRegression`. @@ -602,6 +603,19 @@ class TreeEnsembleParams(DecisionTreeParams): "used for learning each decision tree, in range (0, 1].", typeConverter=TypeConverters.toFloat) + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: 'auto' (choose automatically for task: If numTrees == 1, set to " + + "'all'. 
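With LinearRegressionModel now mixing in GeneralJavaMLWritable, its writer gains a format() hook; a sketch of the PMML export path exercised by the doctest and tests (assumes a fitted `model` and a writable `path`):

model.write().format("pmml").save(path)                      # PMML export
model.write().format("internal").save(path + "_internal")   # "internal" keeps the default Spark ML format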
If numTrees > 1 (forest), set to 'sqrt' for classification and to " + + "'onethird' for regression), 'all' (use all features), 'onethird' (use " + + "1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " + + "log2(number of features)), 'n' (when n is in the range (0, 1.0], use " + + "n * number of features. When n is in the range (1, number of features), use" + + " n features). default = 'auto'", typeConverter=TypeConverters.toString) + def __init__(self): super(TreeEnsembleParams, self).__init__() @@ -619,6 +633,22 @@ def getSubsamplingRate(self): """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + + .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0. + """ + return self._set(featureSubsetStrategy=value) + + @since("1.4.0") + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + class TreeRegressorParams(Params): """ @@ -654,14 +684,8 @@ class RandomForestParams(TreeEnsembleParams): Private class to track supported random forest parameters. """ - supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).", typeConverter=TypeConverters.toInt) - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", - typeConverter=TypeConverters.toString) def __init__(self): super(RandomForestParams, self).__init__() @@ -680,20 +704,6 @@ def getNumTrees(self): """ return self.getOrDefault(self.numTrees) - @since("1.4.0") - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - return self._set(featureSubsetStrategy=value) - - @since("1.4.0") - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. - """ - return self.getOrDefault(self.featureSubsetStrategy) - class GBTParams(TreeEnsembleParams): """ @@ -981,6 +991,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): @@ -1029,6 +1046,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) >>> print(gbt.getImpurity()) variance + >>> print(gbt.getFeatureSubsetStrategy()) + all >>> model = gbt.fit(df) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1056,6 +1075,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))], + ... ["label", "features"]) + >>> model.evaluateEachIteration(validation, "squared") + [0.0, 0.0, 0.0, 0.0, 0.0] .. 
versionadded:: 1.4.0 """ @@ -1075,20 +1098,20 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impurity="variance"): + impurity="variance", featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, - impurity="variance") + impurity="variance", featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1098,13 +1121,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impuriy="variance"): + impuriy="variance", featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Regression. """ kwargs = self._input_kwargs @@ -1127,6 +1150,13 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ @@ -1156,6 +1186,20 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset, loss): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + :param loss: + The loss function used to compute error. + Supported options: squared, absolute + """ + return self._call_java("evaluateEachIteration", dataset, loss) + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, @@ -1331,7 +1375,7 @@ def intercept(self): @since("1.6.0") def scale(self): """ - Model scale paramter. + Model scale parameter. 
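Putting the two GBT additions together in one sketch; `train` and `validation` are illustrative DataFrames with "label" and "features" columns:

from pyspark.ml.regression import GBTRegressor

gbt = GBTRegressor(maxIter=20, featureSubsetStrategy="sqrt")
model = gbt.fit(train)
errors = model.evaluateEachIteration(validation, "squared")  # or "absolute"
best_iteration = errors.index(min(errors)) + 1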
""" return self._call_java("scale") diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index a06ab31a7a56a..370154fc6d62a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -388,8 +388,14 @@ def summary(self, featuresCol, weightCol=None): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.stat from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.stat.__dict__.copy() # The small batch size here ensures that we see multiple batches, diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2ec0be60e9fa9..5c87d1de4139b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -681,6 +681,13 @@ def test_stopwordsremover(self): self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BELKİ"] + dataset = self.spark.createDataFrame([Row(input=["belki"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) def test_count_vectorizer_with_binary(self): dataset = self.spark.createDataFrame([ @@ -837,6 +844,23 @@ def test_string_indexer_from_labels(self): .select(model_default.getOrDefault(model_default.outputCol)).collect() self.assertEqual(len(transformed_list), 5) + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + class HasInducedError(Params): @@ -943,6 +967,13 @@ def test_fit_maximize_metric(self): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_param_grid_type_coercion(self): + lr = LogisticRegression(maxIter=10) + paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() + for param in paramGrid: + for v in param.values(): + assert(type(v) == float) + def test_save_load_trained_model(self): # This tests saving and loading the trained model only. # Save/load for CrossValidator will be added later: SPARK-13786 @@ -1355,6 +1386,23 @@ def test_linear_regression(self): except OSError: pass + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). 
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + def test_logistic_regression(self): lr = LogisticRegression(maxIter=1) path = tempfile.mkdtemp() @@ -1595,6 +1643,44 @@ def test_default_read_write(self): self.assertEqual(lr.uid, lr3.uid) self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + class LDATest(SparkSessionTestCase): @@ -1826,6 +1912,7 @@ def test_gaussian_mixture_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 3) def test_bisecting_kmeans_summary(self): data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), @@ -1841,6 +1928,7 @@ def test_bisecting_kmeans_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 20) def test_kmeans_summary(self): data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), @@ -1856,6 +1944,7 @@ def test_kmeans_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 1) class KMeansTests(SparkSessionTestCase): @@ -2136,17 +2225,23 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True # Note that here we enable Hive's support. 
cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") - cls.spark = HiveContext._createForTesting(cls.sc) + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -2662,6 +2757,6 @@ def testDefaultFitMultiple(self): if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 0c8029f293cfe..1f4abf5157335 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -115,7 +115,11 @@ def build(self): """ keys = self._param_grid.keys() grid_values = self._param_grid.values() - return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] + + def to_key_value_pairs(keys, values): + return [(key, key.typeConverter(value)) for key, value in zip(keys, values)] + + return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)] class ValidatorParams(HasSeed): diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a486c6a3fdeb5..e846834761e49 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -30,6 +30,7 @@ from pyspark import SparkContext, since from pyspark.ml.common import inherit_doc from pyspark.sql import SparkSession +from pyspark.util import VersionUtils def _jvm(): @@ -62,7 +63,7 @@ def _randomUID(cls): Generate a unique unicode id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """ - return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:]) + return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:]) @inherit_doc @@ -147,6 +148,23 @@ def overwrite(self): return self +@inherit_doc +class GeneralMLWriter(MLWriter): + """ + Utility class that can save ML instances in different formats. + + .. versionadded:: 2.4.0 + """ + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self.source = source + return self + + @inherit_doc class JavaMLWriter(MLWriter): """ @@ -191,6 +209,24 @@ def session(self, sparkSession): return self +@inherit_doc +class GeneralJavaMLWriter(JavaMLWriter): + """ + (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types + """ + + def __init__(self, instance): + super(GeneralJavaMLWriter, self).__init__(instance) + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self._jwrite.format(source) + return self + + @inherit_doc class MLWritable(object): """ @@ -219,6 +255,17 @@ def write(self): return JavaMLWriter(self) +@inherit_doc +class GeneralJavaMLWritable(JavaMLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`. 
+ """ + + def write(self): + """Returns an GeneralMLWriter instance for this ML instance.""" + return GeneralJavaMLWriter(self) + + @inherit_doc class MLReader(BaseReadWrite): """ @@ -396,6 +443,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): - sparkVersion - uid - paramMap + - defaultParamMap (since 2.4.0) - (optionally, extra metadata) :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc. :param paramMap: If given, this is saved in the "paramMap" field. @@ -417,15 +465,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): """ uid = instance.uid cls = instance.__module__ + '.' + instance.__class__.__name__ - params = instance.extractParamMap() + + # User-supplied param values + params = instance._paramMap jsonParams = {} if paramMap is not None: jsonParams = paramMap else: for p in params: jsonParams[p.name] = params[p] + + # Default param values + jsonDefaultParams = {} + for p in instance._defaultParamMap: + jsonDefaultParams[p.name] = instance._defaultParamMap[p] + basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), - "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams} + "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams, + "defaultParamMap": jsonDefaultParams} if extraMetadata is not None: basicMetadata.update(extraMetadata) return json.dumps(basicMetadata, separators=[',', ':']) @@ -523,11 +580,26 @@ def getAndSetParams(instance, metadata): """ Extract Params from metadata, and set them in the instance. """ + # Set user-supplied param values for paramName in metadata['paramMap']: param = instance.getParam(paramName) paramValue = metadata['paramMap'][paramName] instance.set(param, paramValue) + # Set default param values + majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion']) + major = majorAndMinorVersions[0] + minor = majorAndMinorVersions[1] + + # For metadata file prior to Spark 2.4, there is no default section. + if major > 2 or (major == 2 and minor >= 4): + assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \ + "`defaultParamMap` section not found" + + for paramName in metadata['defaultParamMap']: + paramValue = metadata['defaultParamMap'][paramName] + instance._setDefault(**{paramName: paramValue}) + @staticmethod def loadParamsInstance(path, sc): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index bb281981fd56b..e00ed95ef0701 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -258,6 +258,9 @@ def load(cls, sc, path): model.setThreshold(threshold) return model + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionWithSGD(object): """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 0cbabab13a896..b09469b9f5c2d 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1042,7 +1042,13 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, def _test(): import doctest + import numpy import pyspark.mllib.clustering + try: + # Numpy 1.14+ changed it's string format. 
+ numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 36cb03369b8c0..6c65da58e4e2b 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -532,8 +532,14 @@ def accuracy(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession import pyspark.mllib.evaluation + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.evaluation.__dict__.copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 60d96d8d5ceb8..4afd6666400b0 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1368,6 +1368,12 @@ def R(self): def _test(): import doctest + import numpy + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index bba88542167ad..7e8b15056cabe 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1364,9 +1364,15 @@ def toCoordinateMatrix(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.linalg.distributed.__dict__.copy() spark = SparkSession.builder\ .master("local[2]")\ diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 3c75b132ecad2..6e89bfd691d16 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -259,7 +259,7 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): The KS statistic gives us the maximum distance between the ECDF and the CDF. Intuitively if this statistic is large, the - probabilty that the null hypothesis is true becomes small. + probability that the null hypothesis is true becomes small. For specific details of the implementation, please have a look at the Scala documentation. @@ -303,7 +303,13 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): def _test(): import doctest + import numpy from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. 
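For orientation, a sketch of the metadata layout the writer and reader above now agree on; the values are illustrative, and only user-set params land in paramMap while defaults go to the new defaultParamMap section:

import json

metadata = {"class": "pyspark.ml.classification.LogisticRegression",
            "timestamp": 1535000000000, "sparkVersion": "2.4.0",
            "uid": "LogisticRegression_4abc",
            "paramMap": {"threshold": 0.75, "maxIter": 50},
            "defaultParamMap": {"predictionCol": "prediction"}}
json.dumps(metadata, separators=[',', ':'])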
+ numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = globals().copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1037bab7f1088..4c2ce137e331c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -57,6 +57,7 @@ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs @@ -1762,14 +1763,25 @@ def test_pca(self): self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b44f76747264..380475e706fbe 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,9 +39,11 @@ else: from itertools import imap as map, ifilter as filter +from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer + PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ + UTF8Deserializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -51,6 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration, _exception_message __all__ = ["RDD"] @@ -71,6 +74,7 @@ class PythonEvalType(object): SQL_SCALAR_PANDAS_UDF = 200 SQL_GROUPED_MAP_PANDAS_UDF = 201 SQL_GROUPED_AGG_PANDAS_UDF = 202 + SQL_WINDOW_AGG_PANDAS_UDF = 203 def portable_hash(x): @@ -136,28 +140,13 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) -def _load_from_socket(port, serializer): - sock = None - # Support for both IPv4 and IPv6. - # On most of IPv6-ready systems, IPv6 will take precedence. 
- for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - sock = socket.socket(af, socktype, proto) - try: - sock.settimeout(15) - sock.connect(sa) - except socket.error: - sock.close() - sock = None - continue - break - if not sock: - raise Exception("could not open socket") +def _load_from_socket(sock_info, serializer): + (sockfile, sock) = local_connect_and_auth(*sock_info) # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) # The socket will be automatically closed when garbage-collected. - return serializer.load_stream(sock.makefile("rb", 65536)) + return serializer.load_stream(sockfile) def ignore_unicode_prefix(f): @@ -332,7 +321,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -347,7 +336,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -410,7 +399,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -791,6 +780,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -822,8 +813,8 @@ def collect(self): to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) - return list(_load_from_socket(port, self._jrdd_deserializer)) + sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(sock_info, self._jrdd_deserializer)) def reduce(self, f): """ @@ -840,6 +831,8 @@ def reduce(self, f): ... 
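A minimal sketch of the failure mode the fail_on_stopiteration wrappers address (assumes an active SparkContext `sc`); a StopIteration escaping user code used to end the partition early and silently drop data:

def buggy(x):
    if x > 3:
        raise StopIteration("user code should not do this")
    return x

rdd = sc.parallelize(range(10), 2)
rdd.map(buggy).collect()  # now fails the task with an explicit error
                          # instead of returning a silently truncated result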
ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -911,6 +904,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -943,6 +938,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1342,7 +1340,7 @@ def take(self, num): if len(items) == 0: numPartsToTry = partsScanned * 4 else: - # the first paramter of max is >=1 whenever partsScanned >= 2 + # the first parameter of max is >=1 whenever partsScanned >= 2 numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4) @@ -1352,7 +1350,10 @@ def takeUpToNumLeft(iterator): iterator = iter(iterator) taken = 0 while taken < left: - yield next(iterator) + try: + yield next(iterator) + except StopIteration: + return taken += 1 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) @@ -1636,6 +1637,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: @@ -2380,8 +2383,24 @@ def toLocalIterator(self): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return _load_from_socket(port, self._jrdd_deserializer) + sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + return _load_from_socket(sock_info, self._jrdd_deserializer) + + def barrier(self): + """ + .. note:: Experimental + + Indicates that Spark must launch the tasks together for the current stage. + + .. versionadded:: 2.4.0 + """ + return RDDBarrier(self) + + def _is_barrier(self): + """ + Whether this RDD is in a barrier stage. + """ + return self._jrdd.rdd().isBarrier() def _prepare_for_python_RDD(sc, command): @@ -2406,6 +2425,33 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class RDDBarrier(object): + + """ + .. note:: Experimental + + An RDDBarrier turns an RDD into a barrier RDD, which forces Spark to launch tasks of the stage + contains this RDD together. + + .. versionadded:: 2.4.0 + """ + + def __init__(self, rdd): + self.rdd = rdd + + def mapPartitions(self, f, preservesPartitioning=False): + """ + .. note:: Experimental + + Return a new RDD by applying a function to each partition of this RDD. + + .. 
versionadded:: 2.4.0 + """ + def func(s, iterator): + return f(iterator) + return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True) + + class PipelinedRDD(RDD): """ @@ -2425,7 +2471,7 @@ class PipelinedRDD(RDD): 20 """ - def __init__(self, prev, func, preservesPartitioning=False): + def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False): if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): # This transformation is the first in its stage: self.func = func @@ -2451,6 +2497,7 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None + self.is_barrier = prev._is_barrier() or isFromBarrier def getNumPartitions(self): return self._prev_jrdd.partitions().size() @@ -2470,7 +2517,7 @@ def _jrdd(self): wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer, profiler) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, - self.preservesPartitioning) + self.preservesPartitioning, self.is_barrier) self._jrdd_val = python_rdd.asJavaRDD() if profiler: @@ -2486,6 +2533,9 @@ def id(self): def _is_pipelinable(self): return not (self.is_cached or self.is_checkpointed) + def _is_barrier(self): + return self.is_barrier + def _test(): import doctest diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd903..48006778e86f2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,8 +33,9 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -PySpark serialize objects in batches; By default, the batch size is chosen based -on the size of objects, also configurable by SparkContext's C{batchSize} parameter: +PySpark serializes objects in batches; by default, the batch size is chosen based +on the size of objects and is also configurable by SparkContext's C{batchSize} +parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -100,7 +101,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ Return an iterator of deserialized batches (iterable) of objects from the input stream. - if the serializer does not operate on batches the default implementation returns an + If the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ return map(lambda x: [x], self.load_stream(stream)) @@ -184,27 +185,31 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowStreamSerializer(Serializer): """ - Serializes bytes as Arrow data with the Arrow file format. + Serializes Arrow record batches as a stream. 
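A sketch of the new barrier hook (assumes an active SparkContext `sc`); barrier() forces all tasks of the resulting mapPartitions stage to be launched together:

rdd = sc.parallelize(range(8), 4)
sizes = rdd.barrier().mapPartitions(lambda it: [sum(1 for _ in it)]).collect()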
""" - def dumps(self, batch): + def dump_stream(self, iterator, stream): import pyarrow as pa - import io - sink = io.BytesIO() - writer = pa.RecordBatchFileWriter(sink, batch.schema) - writer.write_batch(batch) - writer.close() - return sink.getvalue() + writer = None + try: + for batch in iterator: + if writer is None: + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() - def loads(self, obj): + def load_stream(self, stream): import pyarrow as pa - reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) - return reader.read_all() + reader = pa.open_stream(stream) + for batch in reader: + yield batch def __repr__(self): - return "ArrowSerializer" + return "ArrowStreamSerializer" def _create_batch(series, timezone): @@ -215,9 +220,10 @@ def _create_batch(series, timezone): :param timezone: A timezone to respect when handling timestamp values :return: Arrow RecordBatch """ - - from pyspark.sql.types import _check_series_convert_timestamps_internal + import decimal + from distutils.version import LooseVersion import pyarrow as pa + from pyspark.sql.types import _check_series_convert_timestamps_internal # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ (len(series) == 2 and isinstance(series[1], pa.DataType)): @@ -227,14 +233,21 @@ def _create_batch(series, timezone): def create_array(s, t): mask = s.isnull() # Ensure timestamp series are in expected form for Spark internal representation + # TODO: maybe don't need None check anymore as of Arrow 0.9.1 if t is not None and pa.types.is_timestamp(t): s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) elif t is not None and pa.types.is_string(t) and sys.version < '3': # TODO: need decode before converting to Arrow in Python 2 + # TODO: don't need as of Arrow 0.9.1 return pa.Array.from_pandas(s.apply( lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) + elif t is not None and pa.types.is_decimal(t) and \ + LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. + return pa.Array.from_pandas(s.apply( + lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] @@ -461,7 +474,7 @@ def dumps(self, obj): return obj -# Hook namedtuple, make it picklable +# Hack namedtuple, make it picklable __cls = {} @@ -525,15 +538,15 @@ def namedtuple(*args, **kwargs): cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) - # replace namedtuple with new one + # replace namedtuple with the new one collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 - # hack the cls already generated by namedtuple - # those created in other module can be pickled as normal, + # hack the cls already generated by namedtuple. 
+ # Those created in other modules can be pickled as normal, # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple @@ -627,7 +640,7 @@ def loads(self, obj): elif _type == b'P': return pickle.loads(obj[1:]) else: - raise ValueError("invalid sevialization type: %s" % _type) + raise ValueError("invalid serialization type: %s" % _type) class CompressedSerializer(FramedSerializer): @@ -706,6 +719,13 @@ def write_int(value, stream): stream.write(struct.pack("!i", value)) +def read_bool(stream): + length = stream.read(1) + if not length: + raise EOFError + return struct.unpack("!?", length)[0] + + def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index b5fcf7092d93a..472c3cd4452f0 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -38,25 +38,13 @@ SparkContext._ensure_initialized() try: - # Try to access HiveConf, it will raise exception if Hive is not added - conf = SparkConf() - if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() - else: - spark = SparkSession.builder.getOrCreate() -except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() -except TypeError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() + spark = SparkSession._create_shell_session() +except Exception: + import sys + import traceback + warnings.warn("Failed to initialize Spark session.") + traceback.print_exc(file=sys.stderr) + sys.exit(1) sc = spark.sparkContext sql = spark.sql diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -94,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index b0d8357f4feec..974251f63b37a 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -177,8 +177,7 @@ def createTable(self, tableName, path=None, source=None, schema=None, **options) if path is not None: options["path"] = path if source is None: - source = self._sparkSession.conf.get( - "spark.sql.sources.default", "org.apache.spark.sql.parquet") + source = self._sparkSession._wrapped._conf.defaultDataSourceName() if schema is None: df = 
self._jcatalog.createTable(tableName, source, options) else: diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index db49040e17b63..71ea1631718f1 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -20,6 +20,9 @@ from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix +if sys.version_info[0] >= 3: + basestring = str + class RuntimeConfig(object): """User-facing configuration API, accessible through `SparkSession.conf`. @@ -59,10 +62,18 @@ def unset(self, key): def _checkType(self, obj, identifier): """Assert that an object is of type str.""" - if not isinstance(obj, str) and not isinstance(obj, unicode): + if not isinstance(obj, basestring): raise TypeError("expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)) + @ignore_unicode_prefix + @since(2.4) + def isModifiable(self, key): + """Indicates whether the configuration property with the given key + is modifiable in the current session. + """ + return self._jconf.isModifiable(key) + def _test(): import os diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e9ec7ba866761..9c094dd9a9033 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -93,6 +93,11 @@ def _ssql_ctx(self): """ return self._jsqlContext + @property + def _conf(self): + """Accessor for the JVM SQL-specific configurations""" + return self.sparkSession._jsparkSession.sessionState().conf() + @classmethod @since(1.6) def getOrCreate(cls, sc): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 16f8e52dead7b..1affc9b4fcf6c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,7 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ +from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -78,6 +78,9 @@ def __init__(self, jdf, sql_ctx): self.is_cached = False self._schema = None # initialized lazily self._lazy_rdd = None + # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice + # by __repr__ and _repr_html_ while eager evaluation opened. + self._support_repr_html = False @property @since(1.3) @@ -290,6 +293,31 @@ def explain(self, extended=False): else: print(self._jdf.queryExecution().simpleString()) + @since(2.4) + def exceptAll(self, other): + """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but + not in another :class:`DataFrame` while preserving duplicates. + + This is equivalent to `EXCEPT ALL` in SQL. + + >>> df1 = spark.createDataFrame( + ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.exceptAll(df2).show() + +---+---+ + | C1| C2| + +---+---+ + | a| 1| + | a| 1| + | a| 2| + | c| 4| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). 
+ """ + return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx) + @since(1.3) def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally @@ -352,7 +380,46 @@ def show(self, n=20, truncate=True, vertical=False): print(self._jdf.showString(n, int(truncate), vertical)) def __repr__(self): - return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled(): + vertical = False + return self._jdf.showString( + self.sql_ctx._conf.replEagerEvalMaxNumRows(), + self.sql_ctx._conf.replEagerEvalTruncate(), vertical) + else: + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + def _repr_html_(self): + """Returns a dataframe with html code when you enabled eager evaluation + by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are + using support eager evaluation with HTML. + """ + import cgi + if not self._support_repr_html: + self._support_repr_html = True + if self.sql_ctx._conf.isReplEagerEvalEnabled(): + max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0) + sock_info = self._jdf.getRowsToPython( + max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate()) + rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + head = rows[0] + row_data = rows[1:] + has_more_data = len(row_data) > max_num_rows + row_data = row_data[:max_num_rows] + + html = "\n" + # generate table head + html += "\n" % "\n" % "
      %s
      ".join(map(lambda x: cgi.escape(x), head)) + # generate table rows + for row in row_data: + html += "
      %s
      ".join( + map(lambda x: cgi.escape(x), row)) + html += "
      \n" + if has_more_data: + html += "only showing top %d %s\n" % ( + max_num_rows, "row" if max_num_rows == 1 else "rows") + return html + else: + return None @since(2.1) def checkpoint(self, eager=True): @@ -463,8 +530,8 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectToPython() - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + sock_info = self._jdf.collectToPython() + return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(2.0) @@ -477,8 +544,8 @@ def toLocalIterator(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.toPythonIterator() - return _load_from_socket(port, BatchedSerializer(PickleSerializer())) + sock_info = self._jdf.toPythonIterator() + return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer())) @ignore_unicode_prefix @since(1.3) @@ -1433,6 +1500,28 @@ def intersect(self, other): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + @since(2.4) + def intersectAll(self, other): + """ Return a new :class:`DataFrame` containing rows in both this dataframe and other + dataframe while preserving duplicates. + + This is equivalent to `INTERSECT ALL` in SQL. + >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.intersectAll(df2).sort("C1", "C2").show() + +---+---+ + | C1| C2| + +---+---+ + | a| 1| + | a| 1| + | b| 3| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). + """ + return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx) + @since(1.3) def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame @@ -1975,6 +2064,8 @@ def toPandas(self): .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. 
+ >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice @@ -1985,13 +2076,12 @@ def toPandas(self): import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.sql_ctx.getConf("spark.sql.session.timeZone") + if self.sql_ctx._conf.pandasRespectSessionTimeZone(): + timezone = self.sql_ctx._conf.sessionLocalTimeZone() else: timezone = None - if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + if self.sql_ctx._conf.arrowEnabled(): use_arrow = True try: from pyspark.sql.types import to_arrow_schema @@ -2001,8 +2091,7 @@ def toPandas(self): to_arrow_schema(self.schema) except Exception as e: - if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self.sql_ctx._conf.arrowFallbackEnabled(): msg = ( "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " @@ -2029,10 +2118,9 @@ def toPandas(self): from pyspark.sql.types import _check_dataframe_convert_date, \ _check_dataframe_localize_timestamps import pyarrow - - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) + batches = self._collectAsArrow() + if len(batches) > 0: + table = pyarrow.Table.from_batches(batches) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2081,14 +2169,14 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. + Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + and available on driver and worker Python environments. .. note:: Experimental. """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) + sock_info = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(sock_info, ArrowStreamSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index da32ab25cad0c..d58d8d10e5cd3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -152,6 +152,9 @@ def _(): _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_list('age')).collect() [Row(collect_list(age)=[2, 5, 5])] @@ -159,6 +162,9 @@ def _(): _collect_set_doc = """ Aggregate function: returns a set of objects with duplicate elements eliminated. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_set('age')).collect() [Row(collect_set(age)=[5, 2])] @@ -401,6 +407,9 @@ def first(col, ignorenulls=False): The function by default returns the first values it sees. It will return the first non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. 
note:: The function is non-deterministic because its results depends on order of rows which + may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) @@ -489,6 +498,9 @@ def last(col, ignorenulls=False): The function by default returns the last values it sees. It will return the last non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows + which may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) @@ -504,6 +516,8 @@ def monotonically_increasing_id(): within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + .. note:: The function is non-deterministic because its result depends on partition IDs. + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -536,6 +550,8 @@ def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples from U[0.0, 1.0]. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('rand', rand(seed=42) * 3).collect() [Row(age=2, name=u'Alice', rand=1.1568609015300986), Row(age=5, name=u'Bob', rand=1.403379671529166)] @@ -554,6 +570,8 @@ def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('randn', randn(seed=42)).collect() [Row(age=2, name=u'Alice', randn=-0.7556247885860078), Row(age=5, name=u'Bob', randn=-0.0861619008451133)] @@ -1088,16 +1106,23 @@ def add_months(start, months): @since(1.5) -def months_between(date1, date2): +def months_between(date1, date2, roundOff=True): """ - Returns the number of months between date1 and date2. + Returns number of months between dates date1 and date2. + If date1 is later than date2, then the result is positive. + If date1 and date2 are on the same day of month, or both are the last day of month, + returns an integer (time of day will be ignored). + The result is rounded off to 8 digits unless `roundOff` is set to `False`. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() - [Row(months=3.9495967...)] + [Row(months=3.94959677)] + >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() + [Row(months=3.9495967741935485)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + return Column(sc._jvm.functions.months_between( + _to_java_column(date1), _to_java_column(date2), roundOff)) @since(2.2) @@ -1260,11 +1285,21 @@ def from_utc_timestamp(timestamp, tz): that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 03:40:00.0'. 
- >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) + >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect() [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1275,11 +1310,21 @@ def to_utc_timestamp(timestamp, tz): zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 01:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1794,6 +1839,25 @@ def create_map(*cols): return Column(jc) +@since(2.4) +def map_from_arrays(col1, col2): + """Creates a new map from two arrays. + + :param col1: name of column containing a set of keys. All elements should not be null + :param col2: name of column containing a set of values + + >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v']) + >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show() + +----------------+ + | map| + +----------------+ + |[2 -> a, 5 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def array(*cols): """Creates a new array column. @@ -1830,6 +1894,55 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def arrays_overlap(a1, a2): + """ + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. 
+ + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + + +@since(2.4) +def slice(x, start, length): + """ + Collection function: returns an array containing all the elements in `x` from index `start` + (or starting from the end if `start` is negative) with the specified `length`. + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() + [Row(sliced=[2, 3]), Row(sliced=[5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.slice(_to_java_column(x), start, length)) + + +@ignore_unicode_prefix +@since(2.4) +def array_join(col, delimiter, null_replacement=None): + """ + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')] + """ + sc = SparkContext._active_spark_context + if null_replacement is None: + return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter)) + else: + return Column(sc._jvm.functions.array_join( + _to_java_column(col), delimiter, null_replacement)) + + @since(1.5) @ignore_unicode_prefix def concat(*cols): @@ -1890,6 +2003,93 @@ def element_at(col, extraction): return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) +@since(2.4) +def array_remove(col, element): + """ + Collection function: Remove all elements that equal to element from the given array. + + :param col: name of column containing array + :param element: element to be removed from the array + + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df.select(array_remove(df.data, 1)).collect() + [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) + + +@since(2.4) +def array_distinct(col): + """ + Collection function: removes duplicate values from the array. + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(2.4) +def array_intersect(col1, col2): + """ + Collection function: returns an array of the elements in the intersection of col1 and col2, + without duplicates. 
+ + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_intersect(df.c1, df.c2)).collect() + [Row(array_intersect(c1, c2)=[u'a', u'c'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) + + +@ignore_unicode_prefix +@since(2.4) +def array_union(col1, col2): + """ + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) + + +@ignore_unicode_prefix +@since(2.4) +def array_except(col1, col2): + """ + Collection function: returns an array of the elements in col1 but not in col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_except(df.c1, df.c2)).collect() + [Row(array_except(c1, c2)=[u'b'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. @@ -2036,12 +2236,13 @@ def json_tuple(col, *fields): return Column(jc) +@ignore_unicode_prefix @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType` - of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an - unparseable string. + Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` + as keys type, :class:`StructType` or :class:`ArrayType` with + the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format :param schema: a StructType or ArrayType of StructType to use when parsing the json column. 
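The new set-style collection functions above compose naturally; a small illustrative sketch (not part of the patch), reusing the same sample data as the doctests and assuming a local SparkSession:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import array_union, array_intersect, array_except

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(["b", "a", "c"], ["c", "d", "a", "f"])], ["c1", "c2"])
    df.select(
        array_union(df.c1, df.c2).alias("union"),             # [b, a, c, d, f]
        array_intersect(df.c1, df.c2).alias("intersection"),  # [a, c]
        array_except(df.c1, df.c2).alias("only_in_c1"),       # [b]
    ).show(truncate=False)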
@@ -2058,16 +2259,28 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] + >>> df.select(from_json(df.value, "MAP<STRING,INT>").alias("json")).collect() + [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=[Row(a=1)])] + >>> schema = schema_of_json(lit('''{"a": 0}''')) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] + >>> data = [(1, '''[1, 2, 3]''')] + >>> schema = ArrayType(IntegerType()) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=[1, 2, 3])] """ sc = SparkContext._active_spark_context if isinstance(schema, DataType): schema = schema.json() + elif isinstance(schema, Column): + schema = _to_java_column(schema) jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) @@ -2109,6 +2322,28 @@ def to_json(col, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.4) +def schema_of_json(col): + """ + Parses a column containing a JSON string and infers its schema in DDL format. + + :param col: string column in json format + + >>> from pyspark.sql.types import * + >>> data = [(1, '{"a": 1}')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(schema_of_json(df.value).alias("json")).collect() + [Row(json=u'struct<a:bigint>')] + >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() + [Row(json=u'struct<a:bigint>')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_json(_to_java_column(col)) + return Column(jc) + + @since(1.5) def size(col): """ @@ -2158,20 +2393,55 @@ def array_max(col): def sort_array(col, asc=True): """ Collection function: sorts the input array in ascending or descending order according - to the natural ordering of the array elements. + to the natural ordering of the array elements. Null elements will be placed at the beginning + of the returned array in ascending order or at the end of the returned array in descending + order. :param col: name of column or expression - >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() - [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() - [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] + [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.4) +def array_sort(col): + """ + Collection function: sorts the input array in ascending order. The elements of the input array + must be orderable. Null elements will be placed at the end of the returned array.
+ + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) + >>> df.select(array_sort(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + + +@since(2.4) +def shuffle(col): + """ + Collection function: Generates a random permutation of the given array. + + .. note:: The function is non-deterministic. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data']) + >>> df.select(shuffle(df.data).alias('s')).collect() # doctest: +SKIP + [Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.shuffle(_to_java_column(col))) + + @since(1.5) @ignore_unicode_prefix def reverse(col): @@ -2191,6 +2461,23 @@ def reverse(col): return Column(sc._jvm.functions.reverse(_to_java_column(col))) +@since(2.4) +def flatten(col): + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df.select(flatten(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.flatten(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ @@ -2231,6 +2518,121 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@since(2.4) +def map_entries(col): + """ + Collection function: Returns an unordered array of all entries in the given map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_entries + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_entries("data").alias("entries")).show() + +----------------+ + | entries| + +----------------+ + |[[1, a], [2, b]]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_entries(_to_java_column(col))) + + +@since(2.4) +def map_from_entries(col): + """ + Collection function: Returns a map created from the given array of entries. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |[1 -> a, 2 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(2.4) +def array_repeat(col, count): + """ + Collection function: creates an array containing a column repeated count times. + + >>> df = spark.createDataFrame([('ab',)], ['data']) + >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + [Row(r=[u'ab', u'ab', u'ab'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) + + +@since(2.4) +def arrays_zip(*cols): + """ + Collection function: Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. 
+ + :param cols: columns of arrays to be merged. + + >>> from pyspark.sql.functions import arrays_zip + >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2']) + >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect() + [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) + + +@since(2.4) +def map_concat(*cols): + """Returns the union of all the given maps. + + :param cols: list of column names (string) or list of :class:`Column` expressions + + >>> from pyspark.sql.functions import map_concat + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") + >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) + +--------------------------------+ + |map3 | + +--------------------------------+ + |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| + +--------------------------------+ + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(2.4) +def sequence(start, stop, step=None): + """ + Generate a sequence of integers from `start` to `stop`, incrementing by `step`. + If `step` is not set, incrementing by 1 if `start` is less than or equal to `stop`, + otherwise -1. + + >>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) + >>> df1.select(sequence('C1', 'C2').alias('r')).collect() + [Row(r=[-2, -1, 0, 1, 2])] + >>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) + >>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect() + [Row(r=[4, 2, 0, -2, -4])] + """ + sc = SparkContext._active_spark_context + if step is None: + return Column(sc._jvm.functions.sequence(_to_java_column(start), _to_java_column(stop))) + else: + return Column(sc._jvm.functions.sequence( + _to_java_column(start), _to_java_column(stop), _to_java_column(step))) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): @@ -2309,6 +2711,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. + .. note:: Experimental + The function type of the UDF can be one of the following: 1. SCALAR @@ -2350,7 +2754,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined returnType schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. The length of the returned `pandas.DataFrame` can be arbitrary. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. @@ -2399,6 +2805,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2|6.0| +---+---+ + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. 
+ For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG @@ -2408,10 +2820,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as - output types. + :class:`MapType` and :class:`StructType` are currently not supported as output types. - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and + :class:`pyspark.sql.Window` + + This example shows using grouped aggregated UDFs with groupby: >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( @@ -2428,7 +2842,32 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + This example shows using grouped aggregated UDFs as window functions. Note that only + unbounded window frame is supported at the moment: + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> w = Window \\ + ... .partitionBy('id') \\ + ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.5| + | 1| 2.0| 1.5| + | 2| 3.0| 6.0| + | 2| 5.0| 6.0| + | 2|10.0| 6.0| + +---+----+------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked @@ -2492,6 +2931,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): blacklist = ['map', 'since', 'ignore_unicode_prefix'] __all__ = [k for k, v in globals().items() if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] +__all__ += ["PandasUDFType"] __all__.sort() diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 3505065b648f2..cc1da8e7c1f72 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -211,6 +211,8 @@ def pivot(self, pivot_col, values=None): >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] + >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect() + [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ if values is None: jgd = self._jgd.pivot(pivot_col) @@ -236,6 +238,8 @@ def apply(self, udf): into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. + .. note:: Experimental + :param udf: a grouped map user-defined function returned by :func:`pyspark.sql.functions.pandas_udf`. 
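For readers new to grouped map UDFs, a short illustrative sketch of the apply() flow described above (not part of the patch; it assumes pandas and pyarrow are available):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))

    # The returned pandas.DataFrame's column labels match the declared schema by name.
    @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
    def subtract_mean(pdf):
        return pdf.assign(v=pdf.v - pdf.v.mean())

    df.groupby("id").apply(subtract_mean).show()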
@@ -294,6 +298,12 @@ def _test(): Row(course="dotNET", year=2012, earnings=5000), Row(course="dotNET", year=2013, earnings=48000), Row(course="Java", year=2013, earnings=30000)]).toDF() + globs['df5'] = sc.parallelize([ + Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)), + Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)), + Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)), + Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)), + Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bd79bc2f43e5..49f4e6b2ede1b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, + dropFieldIfAllNull=None, encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,8 +238,17 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param samplingRatio: defines fraction of input JSON objects used for schema inferring. + If None is set, it uses the default value, ``1.0``. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. 
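An illustrative sketch of the new JSON reader options documented above (the input path is hypothetical, not part of the patch):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = (spark.read
          .json("/tmp/people_utf16le.json",  # hypothetical input path
                multiLine=True,
                encoding="UTF-16LE",         # force the charset instead of auto-detection
                samplingRatio=0.5,           # infer the schema from about half of the input objects
                dropFieldIfAllNull=True))    # drop columns that are entirely null or empty arrays/structs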
>>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -256,7 +266,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, + samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -337,7 +348,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + samplingRatio=None, enforceSchema=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -364,6 +376,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -420,6 +442,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise. + :param samplingRatio: defines fraction of rows used for schema inferring. + If None is set, it uses the default value, ``1.0``. 
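Similarly, a hypothetical sketch of the new CSV reader options documented above (the path and schema are made up for illustration):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = (spark.read
          .csv("/tmp/ages.csv",               # hypothetical input path
               header=True,
               schema="id INT, name STRING",  # explicit schema as a DDL string
               enforceSchema=False))          # validate the CSV header against the schema instead of ignoring it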
>>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -438,7 +462,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, + enforceSchema=enforceSchema) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -749,7 +774,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - lineSep=None): + lineSep=None, encoding=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -773,6 +798,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param encoding: specifies encoding (charset) of saved json files. If None is set, + the default UTF-8 charset will be used. :param lineSep: defines the line separator that should be used for writing. If None is set, it uses the default value, ``\\n``. @@ -781,7 +808,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm self.mode(mode) self._set_opts( compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - lineSep=lineSep) + lineSep=lineSep, encoding=encoding) self._jwrite.json(path) @since(1.4) @@ -798,10 +825,10 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): exists. :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the - known case-insensitive shorten names (none, snappy, gzip, and lzo). - This will override ``spark.sql.parquet.compression.codec``. If None - is set, it uses the value specified in - ``spark.sql.parquet.compression.codec``. + known case-insensitive shorten names (none, uncompressed, snappy, gzip, + lzo, brotli, lz4, and zstd). This will override + ``spark.sql.parquet.compression.codec``. If None is set, it uses the + value specified in ``spark.sql.parquet.compression.codec``. >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -832,7 +859,7 @@ def text(self, path, compression=None, lineSep=None): def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, - charToEscapeQuoteEscaping=None): + charToEscapeQuoteEscaping=None, encoding=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -882,6 +909,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise.. + :param encoding: sets the encoding (charset) of saved csv files. If None is set, + the default UTF-8 charset will be used. 
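On the write side, a hypothetical sketch combining the writer options discussed above (output paths are made up):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"])
    # 'encoding' is the new JSON writer option; lineSep already existed.
    df.write.mode("overwrite").json("/tmp/out_json", encoding="UTF-16LE", lineSep="\n")
    # The documented Parquet codec names now also include uncompressed, brotli, lz4 and zstd.
    df.write.mode("overwrite").parquet("/tmp/out_parquet", compression="snappy")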
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -891,7 +920,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No dateFormat=dateFormat, timestampFormat=timestampFormat, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, + encoding=encoding) self._jwrite.csv(path) @since(1.5) @@ -966,7 +996,7 @@ def _test(): globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') try: - spark = SparkSession.builder.enableHiveSupport().getOrCreate() + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: spark = SparkSession(sc) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 13d6e2e53dbd0..87d8d6a59a6e9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -501,7 +501,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. """ - from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.serializers import ArrowStreamSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,14 +539,43 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): struct.names[i] = name schema = struct - # Create the Spark DataFrame directly from the Arrow data and schema - jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer()) - jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame( - jrdd, schema.json(), self._wrapped._jsqlContext) + def reader_func(temp_filename): + return self._jvm.PythonSQLUtils.arrowReadStreamFromFile( + self._wrapped._jsqlContext, temp_filename, schema.json()) + + # Create Spark DataFrame from Arrow stream file, using one batch per partition + jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func) df = DataFrame(jdf, self._wrapped) df._schema = schema return df + @staticmethod + def _create_shell_session(): + """ + Initialize a SparkSession for a pyspark shell session. This is called from shell.py + to make error handling simpler without needing to declare local variables in that + script, which would expose those to users. + """ + import py4j + from pyspark.conf import SparkConf + from pyspark.context import SparkContext + try: + # Try to access HiveConf, it will raise exception if Hive is not added + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + return SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + return SparkSession.builder.getOrCreate() + except (py4j.protocol.Py4JError, TypeError): + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") + + return SparkSession.builder.getOrCreate() + @since(2.0) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): @@ -584,6 +613,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr .. 
versionchanged:: 2.1 Added verifySchema. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> l = [('Alice', 1)] >>> spark.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] @@ -649,9 +680,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.conf.get("spark.sql.session.timeZone") + if self._wrapped._conf.pandasRespectSessionTimeZone(): + timezone = self._wrapped._conf.sessionLocalTimeZone() else: timezone = None @@ -661,15 +691,13 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr (x.encode('utf-8') if not isinstance(x, str) else x) for x in data.columns] - if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ - and len(data) > 0: + if self._wrapped._conf.arrowEnabled() and len(data) > 0: try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: from pyspark.util import _exception_message - if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self._wrapped._conf.arrowFallbackEnabled(): msg = ( "createDataFrame attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 15f9407389864..ee13778a7dcd6 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -19,17 +19,16 @@ import json if sys.version >= '3': - intlike = int - basestring = unicode = str -else: - intlike = (int, long) + basestring = str + +from py4j.java_gateway import java_import from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * -from pyspark.sql.utils import StreamingQueryException +from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -564,7 +563,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + enforceSchema=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -592,6 +592,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. 
If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -664,7 +674,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: @@ -843,6 +853,197 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + @since(2.4) + def foreach(self, f): + """ + Sets the output of the streaming query to be processed using the provided writer ``f``. + This is often used to write the output of a streaming query to arbitrary storage systems. + The processing logic can be specified in two ways. + + #. A **function** that takes a row as input. + This is a simple way to express your processing logic. Note that this does + not allow you to deduplicate generated data when failures cause reprocessing of + some input data. That would require you to specify the processing logic in the next + way. + + #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. + The object can have the following methods. + + * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing + (for example, open a connection, start a transaction, etc). Additionally, you can + use the `partition_id` and `epoch_id` to deduplicate regenerated data + (discussed later). + + * ``process(row)``: *Non-optional* method that processes each :class:`Row`. + + * ``close(error)``: *Optional* method that finalizes and cleans up (for example, + close connection, commit transaction, etc.) after all rows have been processed. + + The object will be used by Spark in the following way. + + * A single copy of this object is responsible of all the data generated by a + single task in a query. In other words, one instance is responsible for + processing one partition of the data generated in a distributed manner. + + * This object must be serializable because each task will get a fresh + serialized-deserialized copy of the provided object. Hence, it is strongly + recommended that any initialization for writing data (e.g. opening a + connection or starting a transaction) is done after the `open(...)` + method has been called, which signifies that the task is ready to generate data. + + * The lifecycle of the methods are as follows. + + For each partition with ``partition_id``: + + ... For each batch/epoch of streaming data with ``epoch_id``: + + ....... Method ``open(partitionId, epochId)`` is called. + + ....... 
If ``open(...)`` returns true, for each row in the partition and + batch/epoch, method ``process(row)`` is called. + + ....... Method ``close(errorOrNull)`` is called with error (if any) seen while + processing rows. + + Important points to note: + + * The `partitionId` and `epochId` can be used to deduplicate generated data when + failures cause reprocessing of some input data. This depends on the execution + mode of the query. If the streaming query is being executed in the micro-batch + mode, then every partition represented by a unique tuple (partition_id, epoch_id) + is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + to deduplicate and/or transactionally commit data and achieve exactly-once + guarantees. However, if the streaming query is being executed in the continuous + mode, then this guarantee does not hold and therefore should not be used for + deduplication. + + * The ``close()`` method (if exists) will be called if `open()` method exists and + returns successfully (irrespective of the return value), except if the Python + crashes in the middle. + + .. note:: Evolving. + + >>> # Print every row using a function + >>> def print_row(row): + ... print(row) + ... + >>> writer = sdf.writeStream.foreach(print_row) + >>> # Print every row using a object with process() method + >>> class RowPrinter: + ... def open(self, partition_id, epoch_id): + ... print("Opened %d, %d" % (partition_id, epoch_id)) + ... return True + ... def process(self, row): + ... print(row) + ... def close(self, error): + ... print("Closed with error: %s" % str(error)) + ... + >>> writer = sdf.writeStream.foreach(RowPrinter()) + """ + + from pyspark.rdd import _wrap_function + from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.taskcontext import TaskContext + + if callable(f): + # The provided object is a callable function that is supposed to be called on each row. + # Construct a function that takes an iterator and calls the provided function on each + # row. + def func_without_process(_, iterator): + for x in iterator: + f(x) + return iter([]) + + func = func_without_process + + else: + # The provided object is not a callable function. Then it is expected to have a + # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and + # 'close(error)' methods. 
+ + if not hasattr(f, 'process'): + raise Exception("Provided object does not have a 'process' method") + + if not callable(getattr(f, 'process')): + raise Exception("Attribute 'process' in provided object is not callable") + + def doesMethodExist(method_name): + exists = hasattr(f, method_name) + if exists and not callable(getattr(f, method_name)): + raise Exception( + "Attribute '%s' in provided object is not callable" % method_name) + return exists + + open_exists = doesMethodExist('open') + close_exists = doesMethodExist('close') + + def func_with_open_process_close(partition_id, iterator): + epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if epoch_id: + epoch_id = int(epoch_id) + else: + raise Exception("Could not get batch id from TaskContext") + + # Check if the data should be processed + should_process = True + if open_exists: + should_process = f.open(partition_id, epoch_id) + + error = None + + try: + if should_process: + for x in iterator: + f.process(x) + except Exception as ex: + error = ex + finally: + if close_exists: + f.close(error) + if error: + raise error + + return iter([]) + + func = func_with_open_process_close + + serializer = AutoBatchedSerializer(PickleSerializer()) + wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) + jForeachWriter = \ + self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( + wrapped_func, self._df._jdf.schema()) + self._jwrite.foreach(jForeachWriter) + return self + + @since(2.4) + def foreachBatch(self, func): + """ + Sets the output of the streaming query to be processed using the provided + function. This is supported only the in the micro-batch execution modes (that is, when the + trigger is not continuous). In every micro-batch, the provided function will be called in + every micro-batch with (i) the output rows as a DataFrame and (ii) the batch identifier. + The batchId can be used deduplicate and transactionally write the output + (that is, the provided Dataset) to external systems. The output DataFrame is guaranteed + to exactly same for the same batchId (assuming all operations are deterministic in the + query). + + .. note:: Evolving. + + >>> def func(batch_df, batch_id): + ... batch_df.collect() + ... + >>> writer = sdf.writeStream.foreach(func) + """ + + from pyspark.java_gateway import ensure_callback_server_started + gw = self._spark._sc._gateway + java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*") + + wrapped_func = ForeachBatchFunction(self._spark, func) + gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func) + ensure_callback_server_started(gw) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4e99c8e3c6b10..81c0af0b3d81b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -68,8 +68,16 @@ # If Arrow version requirement is not satisfied, skip related tests. 
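As an aside on the new ``foreach`` and ``foreachBatch`` sinks added to DataStreamWriter above, here is a minimal usage sketch. It assumes a local SparkSession named ``spark``, uses the built-in ``rate`` source, and the output path /tmp/foreach_batch_demo is a hypothetical placeholder:

import time
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("foreach-demo").getOrCreate()

# A trivial streaming source: one row per second with `timestamp` and `value` columns.
stream_df = spark.readStream.format("rate").option("rowsPerSecond", 1).load()

# Row-at-a-time sink: the provided function is invoked once per output row.
def print_row(row):
    print(row)

row_query = stream_df.writeStream.foreach(print_row).start()

# Micro-batch sink: the function receives each micro-batch as a DataFrame plus the
# batch id, which can be used to make the write idempotent on reprocessing.
def write_batch(batch_df, batch_id):
    # /tmp/foreach_batch_demo is a placeholder path for this sketch.
    batch_df.write.mode("append").parquet("/tmp/foreach_batch_demo")

batch_query = stream_df.writeStream.foreachBatch(write_batch).start()

# Let the queries run briefly, then shut them down.
time.sleep(5)
row_query.stop()
batch_query.stop()
spark.stop()

In micro-batch mode the batch id uniquely identifies each batch, so a foreachBatch sink can skip or overwrite output it has already committed for that id.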
_pyarrow_requirement_message = _exception_message(e) +_test_not_compiled_message = None +try: + from pyspark.sql.utils import require_test_compiled + require_test_compiled() +except Exception as e: + _test_not_compiled_message = _exception_message(e) + _have_pandas = _pandas_requirement_message is None _have_pyarrow = _pyarrow_requirement_message is None +_test_compiled = _test_not_compiled_message is None from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row @@ -685,6 +693,13 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_encoding_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, encoding="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + def test_linesep_json(self): df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") expected = [Row(_corrupt_record=None, name=u'Michael'), @@ -758,7 +773,7 @@ def filename(path): row2 = df2.select(sameText(df2['file'])).first() self.assertTrue(row2[0].find("people.json") != -1) - def test_udf_defers_judf_initalization(self): + def test_udf_defers_judf_initialization(self): # This is separate of UDFInitializationTests # to avoid context initialization # when udf is called @@ -1862,6 +1877,299 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert msg in str(e), "%s not in %s" % (msg, str(e)) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def 
_write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + 
self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.sql.utils import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise Exception("this should fail the query") + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -3018,9 +3326,191 @@ def test_sort_with_nulls_order(self): df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) + def test_json_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '{"a":0.1}' if x == 1 else 
'{"a":%s}' % str(x)) + schema = self.spark.read.option('inferSchema', True) \ + .option('samplingRatio', 0.5) \ + .json(rdd).schema + self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + + def test_csv_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '0.1' if x == 1 else str(x)) + schema = self.spark.read.option('inferSchema', True)\ + .csv(rdd, samplingRatio=0.5).schema + self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + + def test_checking_csv_header(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + self.spark.createDataFrame([[1, 1000], [2000, 2]])\ + .toDF('f1', 'f2').write.option("header", "true").csv(path) + schema = StructType([ + StructField('f2', IntegerType(), nullable=True), + StructField('f1', IntegerType(), nullable=True)]) + df = self.spark.read.option('header', 'true').schema(schema)\ + .csv(path, enforceSchema=False) + self.assertRaisesRegexp( + Exception, + "CSV header does not conform to the schema", + lambda: df.collect()) + finally: + shutil.rmtree(path) + + def test_ignore_column_of_all_nulls(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""], + ["""{"a":null, "b":null, "c":"string"}"""], + ["""{"a":null, "b":null, "c":null}"""]]) + df.write.text(path) + schema = StructType([ + StructField('b', LongType(), nullable=True), + StructField('c', StringType(), nullable=True)]) + readback = self.spark.read.json(path, dropFieldIfAllNull=True) + self.assertEquals(readback.schema, schema) + finally: + shutil.rmtree(path) + + # SPARK-24721 + @unittest.skipIf(not _test_compiled, _test_not_compiled_message) + def test_datasource_with_udf(self): + from pyspark.sql.functions import udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load().toDF('i') + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = udf(lambda x: x + 1, 'int')(lit(1)) + c2 = udf(lambda x: x + 1, 'int')(col('i')) + + f1 = udf(lambda x: False, 'boolean')(lit(1)) + f2 = udf(lambda x: False, 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + + def test_repr_behaviors(self): + import re + pattern = re.compile(r'^ *\|', re.MULTILINE) + df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) + + # test when eager evaluation is enabled and _repr_html_ will not be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """+-----+-----+ + || key|value| + |+-----+-----+ + || 1| 1| + ||22222|22222| + |+-----+-----+ + |""" + 
self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + ||222| 222| + |+---+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected2), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) + + # test when eager evaluation is enabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """ + | + | + | + |
<table border='1'>
<tr><th>key</th><th>value</th></tr>
<tr><td>1</td><td>1</td></tr>
<tr><td>22222</td><td>22222</td></tr>
</table>
      + |""" + self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """ + | + | + | + |
<table border='1'>
<tr><th>key</th><th>value</th></tr>
<tr><td>1</td><td>1</td></tr>
<tr><td>222</td><td>222</td></tr>
</table>
      + |""" + self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """ + | + | + |
<table border='1'>
<tr><th>key</th><th>value</th></tr>
<tr><td>1</td><td>1</td></tr>
</table>
      + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + + # test when eager evaluation is disabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): + expected = "DataFrame[key: bigint, value: string]" + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + class HiveSparkSubmitTests(SparkSubmitTests): + @classmethod + def setUpClass(cls): + # get a SparkContext to check for availability of Hive + sc = SparkContext('local[4]', cls.__name__) + cls.hive_available = True + try: + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + finally: + # we don't need this SparkContext for the test + sc.stop() + + def setUp(self): + super(HiveSparkSubmitTests, self).setUp() + if not self.hive_available: + self.skipTest("Hive is not available.") + def test_hivecontext(self): # This test checks that HiveContext is using Hive metastore (SPARK-16224). # It sets a metastore url and checks if there is a derby dir created by @@ -3050,8 +3540,8 @@ def test_hivecontext(self): |print(hive_context.sql("show databases").collect()) """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -3088,23 +3578,28 @@ def setUpClass(cls): filename_pattern = ( "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" "TestQueryExecutionListener.class") - if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): - raise unittest.SkipTest( + cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) + + if cls.has_listener: + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + def setUp(self): + if not self.has_listener: + raise self.skipTest( "'org.apache.spark.sql.TestQueryExecutionListener' is not " "available. Will skip the related tests.") - # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. 
- cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.queryExecutionListeners", - "org.apache.spark.sql.TestQueryExecutionListener") \ - .getOrCreate() - @classmethod def tearDownClass(cls): - cls.spark.stop() + if hasattr(cls, "spark"): + cls.spark.stop() def tearDown(self): self.spark._jvm.OnSuccessCall.clear() @@ -3165,9 +3660,9 @@ def tearDown(self): SparkSession._instantiatedSession.stop() if SparkContext._active_spark_context is not None: - SparkContext._active_spark_contex.stop() + SparkContext._active_spark_context.stop() - def test_udf_init_shouldnt_initalize_context(self): + def test_udf_init_shouldnt_initialize_context(self): from pyspark.sql.functions import UserDefinedFunction UserDefinedFunction(lambda x: x, StringType()) @@ -3188,18 +3683,22 @@ class HiveContextSQLTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False os.unlink(cls.tempdir.name) - cls.spark = HiveContext._createForTesting(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = cls.sc.parallelize(cls.testData).toDF() + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -3600,6 +4099,8 @@ class ArrowTests(ReusedSQLTestCase): def setUpClass(cls): from datetime import date, datetime from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java @@ -3628,6 +4129,13 @@ def setUpClass(cls): (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): + cls.schema.add(StructField("9_binary_t", BinaryType(), True)) + cls.data[0] = cls.data[0] + (bytearray(b"a"),) + cls.data[1] = cls.data[1] + (bytearray(b"bb"),) + cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + @classmethod def tearDownClass(cls): del os.environ["TZ"] @@ -3665,12 +4173,23 @@ def test_toPandas_fallback_enabled(self): self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): + from distutils.version import LooseVersion + import pyarrow as pa + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + schema = StructType([StructField("binary", BinaryType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'): + df.toPandas() + def 
test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -3782,19 +4301,22 @@ def test_createDataFrame_with_schema(self): def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() - wrong_schema = StructType(list(reversed(self.schema))) + fields = list(self.schema) + fields[0], fields[7] = fields[7], fields[0] # swap str with timestamp + wrong_schema = StructType(fields) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() + new_names = list(map(str, range(len(self.schema.fieldNames())))) # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list('abcdefgh')) - self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) + df = self.spark.createDataFrame(pdf, schema=list(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh')) - self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) + df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) def test_createDataFrame_column_name_encoding(self): import pandas as pd @@ -3881,13 +4403,22 @@ def test_createDataFrame_fallback_enabled(self): self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): + from distutils.version import LooseVersion import pandas as pd + import pyarrow as pa with QuietTest(self.sc): with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map") + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'): + self.spark.createDataFrame( + pd.DataFrame([[{'a': b'aaa'}]]), "a: binary") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd @@ -4029,6 +4560,61 @@ def foo(df): def foo(k, v, w): return k + def test_stopiteration_in_udf(self): + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + + # plain udf (test for SPARK-23754) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn('v', udf(foo)('id')).collect + ) + + # pandas scalar udf + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).collect + ) + + # pandas grouped map + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + # pandas grouped agg + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').agg( + pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') + ).collect + ) + 
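For context on test_stopiteration_in_udf above, the following standalone sketch shows the behaviour it exercises: a StopIteration escaping user code in a UDF fails the task with an explicit error instead of silently truncating results. It assumes a fresh local SparkSession and is illustrative only.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.master("local[2]").appName("stopiteration-demo").getOrCreate()

def bad(x):
    # Simulates user code that leaks a StopIteration (e.g. from an exhausted iterator).
    raise StopIteration()

df = spark.range(5)
try:
    df.withColumn("v", udf(bad, "long")("id")).collect()
except Exception as e:
    # The job fails with a Py4JJavaError whose message notes that a StopIteration
    # was caught from the user's code (see exc_message in the test above).
    print("query failed as expected: %s" % type(e).__name__)
finally:
    spark.stop()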
@unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4224,6 +4810,24 @@ def test_vectorized_udf_datatype_string(self): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_binary(self): + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, col + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) + else: + data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)] + schema = StructType().add("binary", BinaryType()) + df = self.spark.createDataFrame(data, schema) + str_f = pandas_udf(lambda x: x, BinaryType()) + res = df.select(str_f(col('binary'))) + self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_array_type(self): from pyspark.sql.functions import pandas_udf, col data = [([1, 2],), ([3, 4],)] @@ -4274,17 +4878,6 @@ def test_vectorized_udf_invalid_length(self): 'Result vector from pandas_udf was not the required length'): df.select(raise_exception(col('id'))).collect() - def test_vectorized_udf_mix_udf(self): - from pyspark.sql.functions import pandas_udf, udf, col - df = self.spark.range(10) - row_by_row_udf = udf(lambda x: x, LongType()) - pd_udf = pandas_udf(lambda x: x, LongType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Can not mix vectorized and non-vectorized UDFs'): - df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect() - def test_vectorized_udf_chained(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) @@ -4295,7 +4888,6 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -4342,12 +4934,6 @@ def test_vectorized_udf_unsupported_types(self): 'Invalid returnType.*scalar Pandas UDF.*MapType'): pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): - pandas_udf(lambda x: x, BinaryType()) - def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col from datetime import date @@ -4572,6 +5158,211 @@ def test_type_annotation(self): df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) self.assertEqual(df.first()[0], 0) + def test_mixed_udf(self): + import pandas as pd + from pyspark.sql.functions import col, udf, pandas_udf + + df = self.spark.range(0, 1).toDF('v') + + # Test mixture of multiple UDFs and Pandas UDFs. 
+ + @udf('int') + def f1(x): + assert type(x) == int + return x + 1 + + @pandas_udf('int') + def f2(x): + assert type(x) == pd.Series + return x + 10 + + @udf('int') + def f3(x): + assert type(x) == int + return x + 100 + + @pandas_udf('int') + def f4(x): + assert type(x) == pd.Series + return x + 1000 + + # Test single expression with chained UDFs + df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v']))) + df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) + df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v']))))) + df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v'])))) + df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v'])))) + + expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11) + expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111) + expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111) + expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011) + expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101) + + self.assertEquals(expected_chained_1.collect(), df_chained_1.collect()) + self.assertEquals(expected_chained_2.collect(), df_chained_2.collect()) + self.assertEquals(expected_chained_3.collect(), df_chained_3.collect()) + self.assertEquals(expected_chained_4.collect(), df_chained_4.collect()) + self.assertEquals(expected_chained_5.collect(), df_chained_5.collect()) + + # Test multiple mixed UDF expressions in a single projection + df_multi_1 = df \ + .withColumn('f1', f1(col('v'))) \ + .withColumn('f2', f2(col('v'))) \ + .withColumn('f3', f3(col('v'))) \ + .withColumn('f4', f4(col('v'))) \ + .withColumn('f2_f1', f2(col('f1'))) \ + .withColumn('f3_f1', f3(col('f1'))) \ + .withColumn('f4_f1', f4(col('f1'))) \ + .withColumn('f3_f2', f3(col('f2'))) \ + .withColumn('f4_f2', f4(col('f2'))) \ + .withColumn('f4_f3', f4(col('f3'))) \ + .withColumn('f3_f2_f1', f3(col('f2_f1'))) \ + .withColumn('f4_f2_f1', f4(col('f2_f1'))) \ + .withColumn('f4_f3_f1', f4(col('f3_f1'))) \ + .withColumn('f4_f3_f2', f4(col('f3_f2'))) \ + .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1'))) + + # Test mixed udfs in a single expression + df_multi_2 = df \ + .withColumn('f1', f1(col('v'))) \ + .withColumn('f2', f2(col('v'))) \ + .withColumn('f3', f3(col('v'))) \ + .withColumn('f4', f4(col('v'))) \ + .withColumn('f2_f1', f2(f1(col('v')))) \ + .withColumn('f3_f1', f3(f1(col('v')))) \ + .withColumn('f4_f1', f4(f1(col('v')))) \ + .withColumn('f3_f2', f3(f2(col('v')))) \ + .withColumn('f4_f2', f4(f2(col('v')))) \ + .withColumn('f4_f3', f4(f3(col('v')))) \ + .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \ + .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \ + .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \ + .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \ + .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v')))))) + + expected = df \ + .withColumn('f1', df['v'] + 1) \ + .withColumn('f2', df['v'] + 10) \ + .withColumn('f3', df['v'] + 100) \ + .withColumn('f4', df['v'] + 1000) \ + .withColumn('f2_f1', df['v'] + 11) \ + .withColumn('f3_f1', df['v'] + 101) \ + .withColumn('f4_f1', df['v'] + 1001) \ + .withColumn('f3_f2', df['v'] + 110) \ + .withColumn('f4_f2', df['v'] + 1010) \ + .withColumn('f4_f3', df['v'] + 1100) \ + .withColumn('f3_f2_f1', df['v'] + 111) \ + .withColumn('f4_f2_f1', df['v'] + 1011) \ + .withColumn('f4_f3_f1', df['v'] + 1101) \ + .withColumn('f4_f3_f2', df['v'] + 1110) \ + .withColumn('f4_f3_f2_f1', df['v'] + 1111) + + self.assertEquals(expected.collect(), df_multi_1.collect()) + self.assertEquals(expected.collect(), df_multi_2.collect()) + + def 
test_mixed_udf_and_sql(self): + import pandas as pd + from pyspark.sql import Column + from pyspark.sql.functions import udf, pandas_udf + + df = self.spark.range(0, 1).toDF('v') + + # Test mixture of UDFs, Pandas UDFs and SQL expression. + + @udf('int') + def f1(x): + assert type(x) == int + return x + 1 + + def f2(x): + assert type(x) == Column + return x + 10 + + @pandas_udf('int') + def f3(x): + assert type(x) == pd.Series + return x + 100 + + df1 = df.withColumn('f1', f1(df['v'])) \ + .withColumn('f2', f2(df['v'])) \ + .withColumn('f3', f3(df['v'])) \ + .withColumn('f1_f2', f1(f2(df['v']))) \ + .withColumn('f1_f3', f1(f3(df['v']))) \ + .withColumn('f2_f1', f2(f1(df['v']))) \ + .withColumn('f2_f3', f2(f3(df['v']))) \ + .withColumn('f3_f1', f3(f1(df['v']))) \ + .withColumn('f3_f2', f3(f2(df['v']))) \ + .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \ + .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \ + .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \ + .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \ + .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \ + .withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) + + expected = df.withColumn('f1', df['v'] + 1) \ + .withColumn('f2', df['v'] + 10) \ + .withColumn('f3', df['v'] + 100) \ + .withColumn('f1_f2', df['v'] + 11) \ + .withColumn('f1_f3', df['v'] + 101) \ + .withColumn('f2_f1', df['v'] + 11) \ + .withColumn('f2_f3', df['v'] + 110) \ + .withColumn('f3_f1', df['v'] + 101) \ + .withColumn('f3_f2', df['v'] + 110) \ + .withColumn('f1_f2_f3', df['v'] + 111) \ + .withColumn('f1_f3_f2', df['v'] + 111) \ + .withColumn('f2_f1_f3', df['v'] + 111) \ + .withColumn('f2_f3_f1', df['v'] + 111) \ + .withColumn('f3_f1_f2', df['v'] + 111) \ + .withColumn('f3_f2_f1', df['v'] + 111) + + self.assertEquals(expected.collect(), df1.collect()) + + # SPARK-24721 + @unittest.skipIf(not _test_compiled, _test_not_compiled_message) + def test_datasource_with_udf(self): + # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF + # This needs to a separate test because Arrow dependency is optional + import pandas as pd + import numpy as np + from pyspark.sql.functions import pandas_udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load().toDF('i') + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1)) + c2 = pandas_udf(lambda x: x + 1, 'int')(col('i')) + + f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) + f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ 
-4629,6 +5420,26 @@ def test_supported_types(self): self.assertPandasEqual(expected2, result2) self.assertPandasEqual(expected3, result3) + def test_array_type_correct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + + df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType()))]) + + udf = pandas_udf( + lambda pdf: pdf, + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4860,6 +5671,140 @@ def foo3(key, pdf): expected4 = udf3.func((), pdf) self.assertPandasEqual(expected4, result4) + def test_column_order(self): + from collections import OrderedDict + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + # Helper function to set column names from a list + def rename_pdf(pdf, names): + pdf.rename(columns={old: new for old, new in + zip(pd_result.columns, names)}, inplace=True) + + df = self.data + grouped_df = df.groupby('id') + grouped_pdf = df.toPandas().groupby('id') + + # Function returns a pdf with required column names, but order could be arbitrary using dict + def change_col_order(pdf): + # Constructing a DataFrame from a dict should result in the same order, + # but use from_items to ensure the pdf column order is different than schema + return pd.DataFrame.from_items([ + ('id', pdf.id), + ('u', pdf.v * 2), + ('v', pdf.v)]) + + ordered_udf = pandas_udf( + change_col_order, + 'id long, v int, u int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by name from the pdf + result = grouped_df.apply(ordered_udf).sort('id', 'v')\ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(change_col_order) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with positional columns, indexed by range + def range_col_order(pdf): + # Create a DataFrame with positional columns, fix types to long + return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') + + range_udf = pandas_udf( + range_col_order, + 'id long, u long, v long', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result uses positional columns from the pdf + result = grouped_df.apply(range_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(range_col_order) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with columns indexed with integers + def int_index(pdf): + return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) + + int_index_udf = pandas_udf( + int_index, + 'id long, u int, v int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by position of integer index + result = grouped_df.apply(int_index_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(int_index) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + 
@pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def column_name_typo(pdf): + return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def invalid_positional_types(pdf): + return pd.DataFrame([(u'a', 1.2)]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): + grouped_df.apply(column_name_typo).collect() + with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + + def test_positional_assignment_conf(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}): + + @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) + def foo(_): + return pd.DataFrame([('hi', 1)], columns=['x', 'y']) + + df = self.data + result = df.groupBy('id').apply(foo).select('a', 'b').collect() + for r in result: + self.assertEqual(r.a, 'hi') + self.assertEqual(r.b, 1) + + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + + def test_mixed_scalar_udfs_followed_by_grouby_apply(self): + import pandas as pd + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + + df = self.spark.range(0, 10).toDF('v1') + df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ + .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) + + result = df.groupby() \ + .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]), + 'sum int', + PandasUDFType.GROUPED_MAP)) + + self.assertEquals(result.collect()[0]['sum'], 165) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -5168,8 +6113,8 @@ def test_complex_groupby(self): expected2 = df.groupby().agg(sum(df.v)) # groupby one column and one sql expression - result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) # groupby one python UDF result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) @@ -5280,6 +6225,15 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + def test_invalid_args(self): from pyspark.sql.functions import mean @@ -5305,9 +6259,238 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or 
_pyarrow_requirement_message) +class WindowPandasUDFTests(ReusedSQLTestCase): + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + return udf(lambda v: v + 1, 'double') + + @property + def pandas_scalar_time_two(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + return pandas_udf(lambda v: v * 2, 'double') + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_max_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def max(v): + return v.max() + return max + + @property + def pandas_agg_min_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def min(v): + return v.min() + return min + + @property + def unbounded_window(self): + return Window.partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + + @property + def ordered_window(self): + return Window.partitionBy('id').orderBy('v') + + @property + def unpartitioned_window(self): + return Window.partitionBy() + + def test_simple(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max + + df = self.data + w = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_multiple_udfs(self): + from pyspark.sql.functions import max, min, mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_replace_existing(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v', mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) + expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + plus_one = 
self.python_plus_one + time_two = self.pandas_scalar_time_two + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn( + 'v2', + plus_one(mean_udf(plus_one(df['v'])).over(w))) + expected1 = df.withColumn( + 'v2', + plus_one(mean(plus_one(df['v'])).over(w))) + + result2 = df.withColumn( + 'v2', + time_two(mean_udf(time_two(df['v'])).over(w))) + expected2 = df.withColumn( + 'v2', + time_two(mean(time_two(df['v'])).over(w))) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_without_partitionBy(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unpartitioned_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v2', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_sql_and_udf(self): + from pyspark.sql.functions import max, min, rank, col + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) + expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) + + # Test mixing sql window function and window udf in the same expression + result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) + expected2 = expected1 + + # Test chaining sql aggregate function and udf + result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') + expected3 = expected1 + + # Test mixing sql window function and udf + result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.withColumn('v2', array_udf(df['v']).over(w)) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + + def test_invalid_args(self): + from pyspark.sql.functions import mean, pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*not supported within a window function'): + foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) + df.withColumn('v2', foo_udf(df['v']).over(w)) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*Only unbounded window frame is supported.*'): + df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: - 
unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1f6534836d64a..0b61707c8cc0a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -206,7 +206,7 @@ class DecimalType(FractionalType): and scale (the number of digits on the right of dot). For example, (5, 2) can support the value from [-999.99 to 999.99]. - The precision can be up to 38, the scale must less or equal to precision. + The precision can be up to 38, the scale must be less or equal to precision. When create a DecimalType, the default precision and scale is (10, 0). When infer schema from decimal.Decimal objects, it will be DecimalType(38, 18). @@ -289,7 +289,8 @@ def __init__(self, elementType, containsNull=True): >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ - assert isinstance(elementType, DataType), "elementType should be DataType" + assert isinstance(elementType, DataType),\ + "elementType %s should be an instance of %s" % (elementType, DataType) self.elementType = elementType self.containsNull = containsNull @@ -343,8 +344,10 @@ def __init__(self, keyType, valueType, valueContainsNull=True): ... == MapType(StringType(), FloatType())) False """ - assert isinstance(keyType, DataType), "keyType should be DataType" - assert isinstance(valueType, DataType), "valueType should be DataType" + assert isinstance(keyType, DataType),\ + "keyType %s should be an instance of %s" % (keyType, DataType) + assert isinstance(valueType, DataType),\ + "valueType %s should be an instance of %s" % (valueType, DataType) self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull @@ -402,8 +405,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None): ... == StructField("f2", StringType(), True)) False """ - assert isinstance(dataType, DataType), "dataType should be DataType" - assert isinstance(name, basestring), "field name should be string" + assert isinstance(dataType, DataType),\ + "dataType %s should be an instance of %s" % (dataType, DataType) + assert isinstance(name, basestring), "field name %s should be string" % (name) if not isinstance(name, str): name = name.encode('utf-8') self.name = name @@ -1574,6 +1578,7 @@ def convert(self, obj, gateway_client): def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ + from distutils.version import LooseVersion import pyarrow as pa if type(dt) == BooleanType: arrow_type = pa.bool_() @@ -1593,6 +1598,12 @@ def to_arrow_type(dt): arrow_type = pa.decimal128(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == BinaryType: + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt) + + "\nPlease install pyarrow >= 0.10.0 for BinaryType support.") + arrow_type = pa.binary() elif type(dt) == DateType: arrow_type = pa.date32() elif type(dt) == TimestampType: @@ -1619,6 +1630,8 @@ def to_arrow_schema(schema): def from_arrow_type(at): """ Convert pyarrow type to Spark data type. 
""" + from distutils.version import LooseVersion + import pyarrow as pa import pyarrow.types as types if types.is_boolean(at): spark_type = BooleanType() @@ -1638,6 +1651,12 @@ def from_arrow_type(at): spark_type = DecimalType(precision=at.precision, scale=at.scale) elif types.is_string(at): spark_type = StringType() + elif types.is_binary(at): + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at) + + "\nPlease install pyarrow >= 0.10.0 for BinaryType support.") + spark_type = BinaryType() elif types.is_date32(at): spark_type = DateType() elif types.is_timestamp(at): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 45363f089a73d..bdb3a1467f1d8 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -150,3 +150,45 @@ def require_minimum_pyarrow_version(): if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): raise ImportError("PyArrow >= %s must be installed; however, " "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) + + +def require_test_compiled(): + """ Raise Exception if test classes are not compiled + """ + import os + import glob + try: + spark_home = os.environ['SPARK_HOME'] + except KeyError: + raise RuntimeError('SPARK_HOME is not defined in environment') + + test_class_path = os.path.join( + spark_home, 'sql', 'core', 'target', '*', 'test-classes') + paths = glob.glob(test_class_path) + + if len(paths) == 0: + raise RuntimeError( + "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path) + + +class ForeachBatchFunction(object): + """ + This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps + the user-defined 'foreachBatch' function such that it can be called from the JVM when + the query is active. 
+ """ + + def __init__(self, sql_ctx, func): + self.sql_ctx = sql_ctx + self.func = func + + def call(self, jdf, batch_id): + from pyspark.sql.dataframe import DataFrame + try: + self.func(DataFrame(jdf, self.sql_ctx), batch_id) + except Exception as e: + self.error = e + raise e + + class Java: + implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 17c34f8a1c54c..3fa57ca85b37b 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -79,22 +79,8 @@ def _ensure_initialized(cls): java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") - # start callback server - # getattr will fallback to JVM, so we cannot test by hasattr() - if "_callback_server" not in gw.__dict__ or gw._callback_server is None: - gw.callback_server_parameters.eager_load = True - gw.callback_server_parameters.daemonize = True - gw.callback_server_parameters.daemonize_connections = True - gw.callback_server_parameters.port = 0 - gw.start_callback_server(gw.callback_server_parameters) - cbport = gw._callback_server.server_socket.getsockname()[1] - gw._callback_server.port = cbport - # gateway with real port - gw._python_proxy_port = gw._callback_server.port - # get the GatewayServer object in JVM by ID - jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) - # update the port of CallbackClient with real port - jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) + from pyspark.java_gateway import ensure_callback_server_started + ensure_callback_server_started(gw) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing @@ -236,7 +222,7 @@ def remember(self, duration): Set each DStreams in this context to remember RDDs it generated in the last given duration. DStreams remember RDDs only for a limited duration of time and releases them for garbage collection. - This method allows the developer to specify how to long to remember + This method allows the developer to specify how long to remember the RDDs (if the developer wishes to query old data outside the DStream computation). @@ -301,7 +287,7 @@ def _check_serializers(self, rdds): def queueStream(self, rdds, oneAtATime=True, default=None): """ - Create an input stream from an queue of RDDs or list. In each batch, + Create an input stream from a queue of RDDs or list. In each batch, it will process either one or all of the RDDs returned by the queue. .. note:: Changes to the queue after the stream is created will not be recognized. 
@@ -338,7 +324,7 @@ def transform(self, dstreams, transformFunc): jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, - lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + lambda t, *rdds: transformFunc(rdds), *[d._jrdd_deserializer for d in dstreams]) jfunc = self._jvm.TransformFunction(func) jdstream = self._jssc.transform(jdstreams, jfunc) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 59977dcb435a8..ce42a857d0c06 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -23,6 +23,8 @@ if sys.version < "3": from itertools import imap as map, ifilter as filter +else: + long = int from py4j.protocol import Py4JJavaError diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 103940923dd4d..5cef621a28e6e 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -63,7 +63,7 @@ def setUpClass(cls): class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir("/tmp") + cls.sc.setCheckpointDir(tempfile.mkdtemp()) @classmethod def tearDownClass(cls): @@ -179,7 +179,7 @@ def func(dstream): self._test_func(input, func, expected) def test_flatMap(self): - """Basic operation test for DStream.faltMap.""" + """Basic operation test for DStream.flatMap.""" input = [range(1, 5), range(5, 9), range(9, 13)] def func(dstream): @@ -206,6 +206,38 @@ def func(dstream): expected = [[len(x)] for x in input] self._test_func(input, func, expected) + def test_slice(self): + """Basic operation test for DStream.slice.""" + import datetime as dt + self.ssc = StreamingContext(self.sc, 1.0) + self.ssc.remember(4.0) + input = [[1], [2], [3], [4]] + stream = self.ssc.queueStream([self.sc.parallelize(d, 1) for d in input]) + + time_vals = [] + + def get_times(t, rdd): + if rdd and len(time_vals) < len(input): + time_vals.append(t) + + stream.foreachRDD(get_times) + + self.ssc.start() + self.wait_for(time_vals, 4) + begin_time = time_vals[0] + + def get_sliced(begin_delta, end_delta): + begin = begin_time + dt.timedelta(seconds=begin_delta) + end = begin_time + dt.timedelta(seconds=end_delta) + rdds = stream.slice(begin, end) + result_list = [rdd.collect() for rdd in rdds] + return [r for result in result_list for r in result] + + self.assertEqual(set([1]), set(get_sliced(0, 0))) + self.assertEqual(set([2, 3]), set(get_sliced(1, 2))) + self.assertEqual(set([2, 3, 4]), set(get_sliced(1, 4))) + self.assertEqual(set([1, 2, 3, 4]), set(get_sliced(0, 4))) + def test_reduce(self): """Basic operation test for DStream.reduce.""" input = [range(1, 5), range(5, 9), range(9, 13)] @@ -779,6 +811,12 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. 
+ dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + def test_get_active(self): self.assertEqual(StreamingContext.getActive(), None) @@ -816,7 +854,7 @@ def setupFunc(): self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) self.assertTrue(self.setupCalled) - # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc + # Verify that getActiveOrCreate() returns active context and does not call the setupFunc self.ssc.start() self.setupCalled = False self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) @@ -1549,7 +1587,9 @@ def search_kinesis_asl_assembly_jar(): kinesis_jar_present = True jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % jars + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, StreamingListenerTests] @@ -1590,11 +1630,11 @@ def search_kinesis_asl_assembly_jar(): sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) if not result.wasSuccessful(): failed = True else: - result = unittest.TextTestRunner(verbosity=3).run(tests) + result = unittest.TextTestRunner(verbosity=2).run(tests) if not result.wasSuccessful(): failed = True sys.exit(failed) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index df184471993ff..b4b9f97feb7ca 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -20,6 +20,8 @@ import traceback import sys +from py4j.java_gateway import is_instance_of + from pyspark import SparkContext, RDD @@ -65,7 +67,14 @@ def call(self, milliseconds, jrdds): t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) if r: - return r._jrdd + # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`. + # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return + # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`. + # See SPARK-17756. + if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"): + return r._jrdd + else: + return r.map(lambda x: x)._jrdd except: self.failure = traceback.format_exc() diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index e5218d9e75e78..53fc2b29e066f 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -16,6 +16,10 @@ # from __future__ import print_function +import socket + +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import write_int, UTF8Deserializer class TaskContext(object): @@ -34,6 +38,7 @@ class TaskContext(object): _partitionId = None _stageId = None _taskAttemptId = None + _localProperties = None def __new__(cls): """Even if users construct TaskContext instead of using get, give them the singleton.""" @@ -88,3 +93,125 @@ def taskAttemptId(self): TaskAttemptID. 
""" return self._taskAttemptId + + def getLocalProperty(self, key): + """ + Get a local property set upstream in the driver, or None if it is missing. + """ + return self._localProperties.get(key, None) + + +BARRIER_FUNCTION = 1 + + +def _load_from_socket(port, auth_secret): + """ + Load data from a given socket, this is a blocking method thus only return when the socket + connection has been closed. + """ + (sockfile, sock) = local_connect_and_auth(port, auth_secret) + # The barrier() call may block forever, so no timeout + sock.settimeout(None) + # Make a barrier() function call. + write_int(BARRIER_FUNCTION, sockfile) + sockfile.flush() + + # Collect result. + res = UTF8Deserializer().loads(sockfile) + + # Release resources. + sockfile.close() + sock.close() + + return res + + +class BarrierTaskContext(TaskContext): + + """ + .. note:: Experimental + + A TaskContext with extra info and tooling for a barrier stage. To access the BarrierTaskContext + for a running task, use: + L{BarrierTaskContext.get()}. + + .. versionadded:: 2.4.0 + """ + + _port = None + _secret = None + + def __init__(self): + """Construct a BarrierTaskContext, use get instead""" + pass + + @classmethod + def _getOrCreate(cls): + """Internal function to get or create global BarrierTaskContext.""" + if cls._taskContext is None: + cls._taskContext = BarrierTaskContext() + return cls._taskContext + + @classmethod + def get(cls): + """ + Return the currently active BarrierTaskContext. This can be called inside of user functions + to access contextual information about running tasks. + + .. note:: Must be called on the worker, not the driver. Returns None if not initialized. + """ + return cls._taskContext + + @classmethod + def _initialize(cls, port, secret): + """ + Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called + after BarrierTaskContext is initialized. + """ + cls._port = port + cls._secret = secret + + def barrier(self): + """ + .. note:: Experimental + + Sets a global barrier and waits until all tasks in this stage hit this barrier. + Note this method is only allowed for a BarrierTaskContext. + + .. versionadded:: 2.4.0 + """ + if self._port is None or self._secret is None: + raise Exception("Not supported to call barrier() before initialize " + + "BarrierTaskContext.") + else: + _load_from_socket(self._port, self._secret) + + def getTaskInfos(self): + """ + .. note:: Experimental + + Returns the all task infos in this barrier stage, the task infos are ordered by + partitionId. + Note this method is only allowed for a BarrierTaskContext. + + .. versionadded:: 2.4.0 + """ + if self._port is None or self._secret is None: + raise Exception("Not supported to call getTaskInfos() before initialize " + + "BarrierTaskContext.") + else: + addresses = self._localProperties.get("addresses", "") + return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")] + + +class BarrierTaskInfo(object): + """ + .. note:: Experimental + + Carries all task infos of a barrier task. + + .. 
versionadded:: 2.4.0 + """ + + def __init__(self, address): + self.address = address diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9111dbbed5929..8ac1df52fc597 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -70,7 +70,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler -from pyspark.taskcontext import TaskContext +from pyspark.taskcontext import BarrierTaskContext, TaskContext _have_scipy = False _have_numpy = False @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -543,6 +574,54 @@ def test_tc_on_driver(self): tc = TaskContext.get() self.assertTrue(tc is None) + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] + self.assertEqual(prop1, value) + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + + def test_barrier(self): + """ + Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks + within a stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + tc.barrier() + return time.time() + + times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() + self.assertTrue(max(times) - min(times) < 1) + + def test_barrier_infos(self): + """ + Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the + barrier stage. 
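The two tests above exercise the new barrier execution mode end to end. Outside the test harness the same pattern looks roughly like the sketch below (assumes a local master with four cores so all barrier tasks can run concurrently; illustrative only):

    from pyspark import SparkContext
    from pyspark.taskcontext import BarrierTaskContext

    sc = SparkContext("local[4]", "barrier-demo")
    rdd = sc.parallelize(range(8), 4)

    def sync_and_report(iterator):
        tc = BarrierTaskContext.get()
        tc.barrier()  # all four tasks block here until every task has arrived
        addresses = [info.address for info in tc.getTaskInfos()]
        yield (tc.partitionId(), addresses)

    print(rdd.barrier().mapPartitions(sync_and_report).collect())
    sc.stop()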
+ """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() + .getTaskInfos()).collect() + self.assertTrue(len(taskInfos) == 4) + self.assertTrue(len(taskInfos[0]) == 4) + class RDDTests(ReusedPySparkTestCase): @@ -1246,6 +1325,35 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_user_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): @@ -1951,7 +2059,12 @@ class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() - self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] def tearDown(self): shutil.rmtree(self.programDir) @@ -2017,7 +2130,7 @@ def test_single_script(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 4, 6]", out.decode('utf-8')) @@ -2033,7 +2146,7 @@ def test_script_with_local_functions(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[3, 6, 9]", out.decode('utf-8')) @@ -2051,7 +2164,7 @@ def test_module_dependency(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], 
stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2070,7 +2183,7 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() @@ -2087,8 +2200,10 @@ def test_package_dependency(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2103,9 +2218,11 @@ def test_package_dependency_on_cluster(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", - "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2124,7 +2241,7 @@ def test_single_script_on_cluster(self): # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2144,7 +2261,7 @@ def test_user_configuration(self): | sc.stop() """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local", script], + self.sparkSubmit + ["--master", "local", script], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, err = proc.communicate() @@ -2303,6 +2420,10 @@ def test_py4j_exception_message(self): self.assertTrue('NullPointerException' in _exception_message(context.exception)) + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -2353,15 +2474,7 @@ def test_statcounter_array(self): if __name__ == "__main__": from pyspark.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if not _have_numpy: - print("NOTE: Skipping NumPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - if not _have_numpy: - print("NOTE: NumPy tests were skipped as it does not seem to be installed") + unittest.main(verbosity=2) diff 
--git a/python/pyspark/util.py b/python/pyspark/util.py index 49afc13640332..f015542c8799d 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +import re import sys import inspect from py4j.protocol import Py4JJavaError @@ -52,15 +53,59 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. if sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec +class VersionUtils(object): + """ + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(sparkVersion): + """ + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). + + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 4) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 3) + + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: + return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") + + +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop in Spark code + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a1a4336b1e8de..e934da4d2eb6e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -22,21 +22,26 @@ import os import sys import time +import resource import socket import traceback from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.taskcontext import TaskContext +from pyspark.java_gateway import local_connect_and_auth +from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType -from pyspark.serializers import write_with_length, write_int, read_long, \ +from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle +if sys.version >= '3': + basestring = str + pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -81,7 +86,7 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " + raise TypeError("Return type of the user-defined function should be " "Pandas.Series, 
but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " @@ -91,10 +96,12 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type): +def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): + assign_cols_by_pos = runner_conf.get( + "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False) + def wrapped(key_series, value_series): import pandas as pd - argspec = _get_argspec(f) if len(argspec.args) == 1: result = f(pd.concat(value_series, axis=1)) @@ -110,9 +117,13 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. " "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] + + # Assign result columns by schema name if user labeled with strings, else use position + if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns): + return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] + else: + return [(result[result.columns[i]], to_arrow_type(field.dataType)) + for i, field in enumerate(return_type)] return wrapped @@ -128,7 +139,22 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type): +def wrap_window_agg_pandas_udf(f, return_type): + # This is similar to grouped_agg_pandas_udf, the only difference + # is that window_agg_pandas_udf needs to repeat the return value + # to match window length, where grouped_agg_pandas_udf just returns + # the scalar value. 
+ arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series([result]).repeat(len(series[0])) + + return lambda *a: (wrapped(*a), arrow_return_type) + + +def read_single_udf(pickleSer, infile, eval_type, runner_conf): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -139,20 +165,47 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) + # make sure StopIteration's raised in the user code are not ignored + # when they are processed in a for loop, raise them as RuntimeError's instead + func = fail_on_stopiteration(row_func) + # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + argspec = _get_argspec(row_func) # signature was lost when wrapping it + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: - return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) + return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) + elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: + return arg_offsets, wrap_window_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: - return arg_offsets, wrap_udf(row_func, return_type) + return arg_offsets, wrap_udf(func, return_type) else: raise ValueError("Unknown eval type: {}".format(eval_type)) def read_udfs(pickleSer, infile, eval_type): + runner_conf = {} + + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): + + # Load conf used for pandas_udf evaluation + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + timezone = runner_conf.get("spark.sql.session.timeZone", None) + ser = ArrowStreamPandasSerializer(timezone) + else: + ser = BatchedSerializer(PickleSerializer(), 100) + num_udfs = read_int(infile) udfs = {} call_udf = [] @@ -167,7 +220,7 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f'] = udf split_offset = arg_offsets[0] + 1 arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] @@ -179,7 +232,7 @@ def read_udfs(pickleSer, infile, eval_type): # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. 
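read_single_udf above now routes every user function through fail_on_stopiteration from pyspark/util.py. A self-contained restatement of that helper with a tiny demo (illustrative only; the real helper is defined earlier in this patch):

    def fail_on_stopiteration(f):
        # A StopIteration escaping user code would silently end the worker's
        # loop over input rows, so surface it as a RuntimeError instead.
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except StopIteration as exc:
                raise RuntimeError(
                    "Caught StopIteration thrown from user's code; failing the task", exc)
        return wrapper

    @fail_on_stopiteration
    def bad_udf(x):
        raise StopIteration()

    try:
        bad_udf(1)
    except RuntimeError as e:
        print(e.args[0])  # Caught StopIteration thrown from user's code; failing the task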
for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) @@ -188,14 +241,6 @@ def read_udfs(pickleSer, infile, eval_type): mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - timezone = utf8_deserializer.loads(infile) - ser = ArrowStreamPandasSerializer(timezone) - else: - ser = BatchedSerializer(PickleSerializer(), 100) - # profiling is not supported for UDF return func, None, ser, ser @@ -215,12 +260,50 @@ def main(infile, outfile): "PYSPARK_DRIVER_PYTHON are correctly set.") % ("%d.%d" % sys.version_info[:2], version)) + # read inputs only for a barrier task + isBarrier = read_bool(infile) + boundPort = read_int(infile) + secret = UTF8Deserializer().loads(infile) + + # set up memory limits + memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1")) + total_memory = resource.RLIMIT_AS + try: + if memory_limit_mb > 0: + (soft_limit, hard_limit) = resource.getrlimit(total_memory) + msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) + print(msg, file=sys.stderr) + + # convert to bytes + new_limit = memory_limit_mb * 1024 * 1024 + + if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: + msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) + print(msg, file=sys.stderr) + resource.setrlimit(total_memory, (new_limit, new_limit)) + + except (resource.error, OSError, ValueError) as e: + # not all systems support resource limits, so warn instead of failing + print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr) + # initialize global state - taskContext = TaskContext._getOrCreate() + taskContext = None + if isBarrier: + taskContext = BarrierTaskContext._getOrCreate() + BarrierTaskContext._initialize(boundPort, secret) + else: + taskContext = TaskContext._getOrCreate() + # read inputs for TaskContext info taskContext._stageId = read_int(infile) taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) taskContext._taskAttemptId = read_long(infile) + taskContext._localProperties = dict() + for i in range(read_int(infile)): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + taskContext._localProperties[k] = v + shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() @@ -301,9 +384,8 @@ def process(): if __name__ == '__main__': - # Read a local port to connect to from stdin - java_port = int(sys.stdin.readline()) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("rwb", 65536) + # Read information about how to connect back to the JVM from the environment. 
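The memory-limit block added to worker.py above relies on the Unix-only resource module. A hedged standalone sketch of the same RLIMIT_AS pattern (512 MiB is an arbitrary example value, not a Spark default):

    import resource

    limit_mb = 512
    new_limit = limit_mb * 1024 * 1024
    soft, hard = resource.getrlimit(resource.RLIMIT_AS)
    if soft == resource.RLIM_INFINITY or new_limit < soft:
        try:
            # Cap both the soft and hard address-space limits for this process.
            resource.setrlimit(resource.RLIMIT_AS, (new_limit, new_limit))
        except (resource.error, OSError, ValueError):
            print("address-space limits are not supported on this platform")
    print(resource.getrlimit(resource.RLIMIT_AS))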
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) main(sock_file, sock_file) diff --git a/python/run-tests.py b/python/run-tests.py index 6b41b5ee22814..4c90926cfa350 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -22,16 +22,19 @@ from optparse import OptionParser import os import re +import shutil import subprocess import sys import tempfile from threading import Thread, Lock import time +import uuid if sys.version < '3': import Queue else: import queue as Queue from distutils.version import LooseVersion +from multiprocessing import Manager # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -50,6 +53,7 @@ def print_red(text): print('\033[31m' + text + '\033[0m') +SKIPPED_TESTS = Manager().dict() LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() @@ -66,7 +70,7 @@ def print_red(text): raise Exception("Cannot find assembly build directory, please build Spark first.") -def run_individual_python_test(test_name, pyspark_python): +def run_individual_python_test(target_dir, test_name, pyspark_python): env = dict(os.environ) env.update({ 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, @@ -75,6 +79,23 @@ def run_individual_python_test(test_name, pyspark_python): 'PYSPARK_PYTHON': which(pyspark_python), 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) }) + + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is + # recognized by the tempfile module to override the default system temp directory. + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + while os.path.isdir(tmp_dir): + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + os.mkdir(tmp_dir) + env["TMPDIR"] = tmp_dir + + # Also override the JVM's temp directory by setting driver and executor options. + spark_args = [ + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "pyspark-shell" + ] + env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args) + LOGGER.info("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -82,6 +103,7 @@ def run_individual_python_test(test_name, pyspark_python): retcode = subprocess.Popen( [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], stderr=per_test_output, stdout=per_test_output, env=env).wait() + shutil.rmtree(tmp_dir, ignore_errors=True) except: LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if @@ -109,8 +131,34 @@ def run_individual_python_test(test_name, pyspark_python): # this code is invoked from a thread other than the main thread. os._exit(-1) else: - per_test_output.close() - LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + skipped_counts = 0 + try: + per_test_output.seek(0) + # Here expects skipped test output from unittest when verbosity level is + # 2 (or --verbose option is enabled). + decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) + skipped_tests = list(filter( + lambda line: re.search('test_.* \(pyspark\..*\) ... 
skipped ', line), + decoded_lines)) + skipped_counts = len(skipped_tests) + if skipped_counts > 0: + key = (pyspark_python, test_name) + SKIPPED_TESTS[key] = skipped_tests + per_test_output.close() + except: + import traceback + print_red("\nGot an exception while trying to store " + "skipped test output:\n%s" % traceback.format_exc()) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + if skipped_counts != 0: + LOGGER.info( + "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name, + duration, skipped_counts) + else: + LOGGER.info( + "Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): @@ -152,65 +200,17 @@ def parse_opts(): return opts -def _check_dependencies(python_exec, modules_to_test): - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - try: - subprocess_check_output( - [python_exec, "-c", "import coverage"], - stderr=open(os.devnull, 'w')) - except: - print_red("Coverage is not installed in Python executable '%s' " - "but 'COVERAGE_PROCESS_START' environment variable is set, " - "exiting." % python_exec) - sys.exit(-1) - - # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and - # explicitly prints out. See SPARK-23300. - if pyspark_sql in modules_to_test: - # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. - minimum_pyarrow_version = '0.8.0' - minimum_pandas_version = '0.19.2' - - try: - pyarrow_version = subprocess_check_output( - [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): - LOGGER.info("Will test PyArrow related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) - except: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) - - try: - pandas_version = subprocess_check_output( - [python_exec, "-c", "import pandas; print(pandas.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): - LOGGER.info("Will test Pandas related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) - except: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) +def _check_coverage(python_exec): + # Make sure if coverage is installed. 
+ try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) def main(): @@ -237,9 +237,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - # Check if the python executable has proper dependencies installed to run tests - # for given modules properly. - _check_dependencies(python_exec, modules_to_test) + # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START' + # environmental variable is set. + if "COVERAGE_PROCESS_START" in os.environ: + _check_coverage(python_exec) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], @@ -257,6 +258,11 @@ def main(): priority = 100 task_queue.put((priority, (python_exec, test_goal))) + # Create the target directory before starting tasks to avoid races. + target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target')) + if not os.path.isdir(target_dir): + os.mkdir(target_dir) + def process_queue(task_queue): while True: try: @@ -264,7 +270,7 @@ def process_queue(task_queue): except Queue.Empty: break try: - run_individual_python_test(test_goal, python_exec) + run_individual_python_test(target_dir, test_goal, python_exec) finally: task_queue.task_done() @@ -281,6 +287,12 @@ def process_queue(task_queue): total_duration = time.time() - start_time LOGGER.info("Tests passed in %i seconds", total_duration) + for key, lines in sorted(SKIPPED_TESTS.items()): + pyspark_python, test_name = key + LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python)) + for line in lines: + LOGGER.info(" %s" % line.rstrip()) + if __name__ == "__main__": main() diff --git a/python/setup.py b/python/setup.py index 794ceceae3008..c447f2d40343d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -34,7 +34,7 @@ print("Failed to load PySpark version file for packaging. You must be in Spark's python dir.", file=sys.stderr) sys.exit(-1) -VERSION = __version__ +VERSION = __version__ # noqa # A temporary path so we can access above the Python project root and fetch scripts and jars we need TEMP_PATH = "deps" SPARK_HOME = os.path.abspath("../") @@ -201,7 +201,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.6'], + install_requires=['py4j==0.10.7'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], @@ -219,6 +219,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json new file mode 100644 index 0000000000000..9c657fa30ac9c Binary files /dev/null and b/python/test_support/sql/people_array_utf16le.json differ diff --git a/repl/pom.xml b/repl/pom.xml index 6f4a863c48bc7..553d5eb79a256 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -102,7 +102,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded @@ -166,15 +166,5 @@
      - - - - scala-2.12 - - scala-2.12/src/main/scala - scala-2.12/src/test/scala - - - diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala deleted file mode 100644 index e69441a475e9a..0000000000000 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import java.io.BufferedReader - -// scalastyle:off println -import scala.Predef.{println => _, _} -// scalastyle:on println -import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} -import scala.tools.nsc.util.stringFromStream -import scala.util.Properties.{javaVersion, javaVmName, versionString} - -/** - * A Spark-specific interactive shell. - */ -class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) - extends ILoop(in0, out) { - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) - def this() = this(None, new JPrintWriter(Console.out, true)) - - override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out) - } - - val initializationCommands: Seq[String] = Seq( - """ - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """, - "import org.apache.spark.SparkContext._", - "import spark.implicits._", - "import spark.sql", - "import org.apache.spark.sql.functions._" - ) - - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. 
- initializationCommands.foreach(processLine) - } - } - } - - /** Print a welcome message */ - override def printWelcome() { - import org.apache.spark.SPARK_VERSION - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } - - /** Available commands */ - override def commands: List[LoopCommand] = standardCommands - - /** - * We override `loadFiles` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. - */ - override def loadFiles(settings: Settings): Unit = { - initializeSpark() - super.loadFiles(settings) - } - - override def resetCommand(line: String): Unit = { - super.resetCommand(line) - initializeSpark() - echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") - } - - override def replay(): Unit = { - initializeSpark() - super.replay() - } - -} - -object SparkILoop { - - /** - * Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ - def run(code: String, sets: Settings = new Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) - - if (sets.classpath.isDefault) { - sets.classpath.value = sys.props("java.class.path") - } - repl process sets - } - } - } - def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) -} diff --git a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala deleted file mode 100644 index ffb2e5f5db7e2..0000000000000 --- a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import java.io.BufferedReader - -import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} -import scala.tools.nsc.util.stringFromStream -import scala.util.Properties.{javaVersion, javaVmName, versionString} - -/** - * A Spark-specific interactive shell. 
- */ -class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) - extends ILoop(in0, out) { - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) - def this() = this(None, new JPrintWriter(Console.out, true)) - - val initializationCommands: Seq[String] = Seq( - """ - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """, - "import org.apache.spark.SparkContext._", - "import spark.implicits._", - "import spark.sql", - "import org.apache.spark.sql.functions._" - ) - - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. - initializationCommands.foreach(command) - } - } - } - - /** Print a welcome message */ - override def printWelcome() { - import org.apache.spark.SPARK_VERSION - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } - - /** Available commands */ - override def commands: List[LoopCommand] = standardCommands - - /** - * We override `createInterpreter` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. - */ - override def createInterpreter(): Unit = { - super.createInterpreter() - initializeSpark() - } - - override def resetCommand(line: String): Unit = { - super.resetCommand(line) - initializeSpark() - echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") - } - - override def replay(): Unit = { - initializeSpark() - super.replay() - } - -} - -object SparkILoop { - - /** - * Creates an interpreter loop with default settings and feeds - * the given code to it as input. 
- */ - def run(code: String, sets: Settings = new Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) - - if (sets.classpath.isDefault) { - sets.classpath.value = sys.props("java.class.path") - } - repl process sets - } - } - } - def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) -} diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 4dc399827ffed..88eb0ad1da3d7 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -22,8 +22,8 @@ import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.xbean.asm5._ -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6._ +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil @@ -187,7 +187,7 @@ class ExecutorClassLoader( } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassVisitor(ASM5, cv) { +extends ClassVisitor(ASM6, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala index cc76a703bdf8f..e4ddcef9772e4 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -44,6 +44,7 @@ object Main extends Logging { var interp: SparkILoop = _ private var hasErrors = false + private var isShellSession = false private def scalaOptionError(msg: String): Unit = { hasErrors = true @@ -53,6 +54,7 @@ object Main extends Logging { } def main(args: Array[String]) { + isShellSession = true doMain(args, new SparkILoop) } @@ -79,44 +81,50 @@ object Main extends Logging { } def createSparkSession(): SparkSession = { - val execUri = System.getenv("SPARK_EXECUTOR_URI") - conf.setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) - if (execUri != null) { - conf.set("spark.executor.uri", execUri) - } - if (System.getenv("SPARK_HOME") != null) { - conf.setSparkHome(System.getenv("SPARK_HOME")) - } + try { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + conf.setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. 
This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } - val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { - if (SparkSession.hiveClassesArePresent) { - // In the case that the property is not set at all, builder's config - // does not have this value set to 'hive' yet. The original default - // behavior is that when there are hive classes, we use hive catalog. - sparkSession = builder.enableHiveSupport().getOrCreate() - logInfo("Created Spark session with Hive support") + val builder = SparkSession.builder.config(conf) + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { + if (SparkSession.hiveClassesArePresent) { + // In the case that the property is not set at all, builder's config + // does not have this value set to 'hive' yet. The original default + // behavior is that when there are hive classes, we use hive catalog. + sparkSession = builder.enableHiveSupport().getOrCreate() + logInfo("Created Spark session with Hive support") + } else { + // Need to change it back to 'in-memory' if no hive classes are found + // in the case that the property is set to hive in spark-defaults.conf + builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + sparkSession = builder.getOrCreate() + logInfo("Created Spark session") + } } else { - // Need to change it back to 'in-memory' if no hive classes are found - // in the case that the property is set to hive in spark-defaults.conf - builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + // In the case that the property is set but not to 'hive', the internal + // default is 'in-memory'. So the sparkSession will use in-memory catalog. sparkSession = builder.getOrCreate() logInfo("Created Spark session") } - } else { - // In the case that the property is set but not to 'hive', the internal - // default is 'in-memory'. So the sparkSession will use in-memory catalog. - sparkSession = builder.getOrCreate() - logInfo("Created Spark session") + sparkContext = sparkSession.sparkContext + sparkSession + } catch { + case e: Exception if isShellSession => + logError("Failed to initialize Spark session.", e) + sys.exit(1) } - sparkContext = sparkSession.sparkContext - sparkSession } } diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala new file mode 100644 index 0000000000000..aa9aa2793b8b3 --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
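Editor's note on the createSparkSession change above: the body is unchanged except for the surrounding try/catch, which makes a shell session log the failure and exit instead of surfacing a raw stack trace. The catalog selection itself reduces to a small rule; a hedged sketch, factored out for clarity (the helper name is mine, not Spark's):

    import java.util.Locale

    // Sketch only: "hive" is honoured only when the Hive classes are actually on the
    // classpath; any other requested value, or missing Hive classes, falls back to
    // the in-memory catalog, matching the branches above.
    def chooseCatalogImplementation(requested: String, hiveClassesPresent: Boolean): String =
      if (requested.toLowerCase(Locale.ROOT) == "hive" && hiveClassesPresent) "hive"
      else "in-memory"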
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import java.io.BufferedReader + +// scalastyle:off println +import scala.Predef.{println => _, _} +// scalastyle:on println +import scala.concurrent.Future +import scala.reflect.classTag +import scala.reflect.io.File +import scala.tools.nsc.{GenericRunnerSettings, Properties} +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter.{isReplDebug, isReplPower, replProps} +import scala.tools.nsc.interpreter.{AbstractOrMissingHandler, ILoop, IMain, JPrintWriter} +import scala.tools.nsc.interpreter.{NamedParam, SimpleReader, SplashLoop, SplashReader} +import scala.tools.nsc.interpreter.StdReplTags.tagOfIMain +import scala.tools.nsc.util.stringFromStream +import scala.util.Properties.{javaVersion, javaVmName, versionNumberString, versionString} + +/** + * A Spark-specific interactive shell. + */ +class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) + extends ILoop(in0, out) { + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) + def this() = this(None, new JPrintWriter(Console.out, true)) + + /** + * TODO: Remove the following `override` when the support of Scala 2.11 is ended + * Scala 2.11 has a bug of finding imported types in class constructors, extends clause + * which is fixed in Scala 2.12 but never be back-ported into Scala 2.11.x. + * As a result, we copied the fixes into `SparkILoopInterpreter`. See SPARK-22393 for detail. + */ + override def createInterpreter(): Unit = { + if (isScala2_11) { + if (addedClasspath != "") { + settings.classpath append addedClasspath + } + // scalastyle:off classforname + // Have to use the default classloader to match the one used in + // `classOf[Settings]` and `classOf[JPrintWriter]`. 
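Editor's note: the Scala 2.11 workaround in this createInterpreter override is gated on a plain runtime version check. A hedged sketch of that gate, using the same `versionNumberString` import as the file above:

    import scala.util.Properties.versionNumberString

    // versionNumberString is e.g. "2.11.12" or "2.12.6", so a prefix check is enough
    // to decide whether the SparkILoopInterpreter workaround is needed.
    val isScala2_11: Boolean = versionNumberString.startsWith("2.11")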
+ intp = Class.forName("org.apache.spark.repl.SparkILoopInterpreter") + .getDeclaredConstructor(Seq(classOf[Settings], classOf[JPrintWriter]): _*) + .newInstance(Seq(settings, out): _*) + .asInstanceOf[IMain] + // scalastyle:on classforname + } else { + super.createInterpreter() + } + } + + private val isScala2_11 = versionNumberString.startsWith("2.11") + + val initializationCommands: Seq[String] = Seq( + """ + @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { + org.apache.spark.repl.Main.sparkSession + } else { + org.apache.spark.repl.Main.createSparkSession() + } + @transient val sc = { + val _sc = spark.sparkContext + if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { + val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) + if (proxyUrl != null) { + println( + s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") + } else { + println(s"Spark Context Web UI is available at Spark Master Public URL") + } + } else { + _sc.uiWebUrl.foreach { + webUrl => println(s"Spark context Web UI available at ${webUrl}") + } + } + println("Spark context available as 'sc' " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + println("Spark session available as 'spark'.") + _sc + } + """, + "import org.apache.spark.SparkContext._", + "import spark.implicits._", + "import spark.sql", + "import org.apache.spark.sql.functions._" + ) + + def initializeSpark(): Unit = { + if (!intp.reporter.hasErrors) { + // `savingReplayStack` removes the commands from session history. + savingReplayStack { + initializationCommands.foreach(intp quietRun _) + } + } else { + throw new RuntimeException(s"Scala $versionString interpreter encountered " + + "errors during initialization") + } + } + + /** Print a welcome message */ + override def printWelcome() { + import org.apache.spark.SPARK_VERSION + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + /** Available commands */ + override def commands: List[LoopCommand] = standardCommands + + override def resetCommand(line: String): Unit = { + super.resetCommand(line) + initializeSpark() + echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") + } + + override def replay(): Unit = { + initializeSpark() + super.replay() + } + + /** + * TODO: Remove `runClosure` when the support of Scala 2.11 is ended + */ + private def runClosure(body: => Boolean): Boolean = { + if (isScala2_11) { + // In Scala 2.11, there is a bug that interpret could set the current thread's + // context classloader, but fails to reset it to its previous state when returning + // from that method. This is fixed in SI-8521 https://github.com/scala/scala/pull/5657 + // which is never back-ported into Scala 2.11.x. The following is a workaround fix. + val original = Thread.currentThread().getContextClassLoader + try { + body + } finally { + Thread.currentThread().setContextClassLoader(original) + } + } else { + body + } + } + + /** + * The following code is mostly a copy of `process` implementation in `ILoop.scala` in Scala + * + * In newer version of Scala, `printWelcome` is the first thing to be called. 
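Editor's note on the initialization commands above: the Web UI banner line is chosen from two reverse-proxy settings. A hedged, stand-alone sketch of that decision (the function name and parameters are illustrative, not part of the patch):

    // Sketch only: mirrors the branching in initializationCommands above.
    def uiBannerLine(reverseProxy: Boolean, proxyUrl: Option[String],
        appId: String, uiWebUrl: Option[String]): Option[String] =
      if (reverseProxy) {
        Some(proxyUrl
          .map(p => s"Spark Context Web UI is available at $p/proxy/$appId")
          .getOrElse("Spark Context Web UI is available at Spark Master Public URL"))
      } else {
        uiWebUrl.map(u => s"Spark context Web UI available at $u")
      }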
As a result, + * SparkUI URL information would be always shown after the welcome message. + * + * However, this is inconsistent compared with the existing version of Spark which will always + * show SparkUI URL first. + * + * The only way we can make it consistent will be duplicating the Scala code. + * + * We should remove this duplication once Scala provides a way to load our custom initialization + * code, and also customize the ordering of printing welcome message. + */ + override def process(settings: Settings): Boolean = runClosure { + + def newReader = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) + + /** Reader to use before interpreter is online. */ + def preLoop = { + val sr = SplashReader(newReader) { r => + in = r + in.postInit() + } + in = sr + SplashLoop(sr, prompt) + } + + /* Actions to cram in parallel while collecting first user input at prompt. + * Run with output muted both from ILoop and from the intp reporter. + */ + def loopPostInit(): Unit = mumly { + // Bind intp somewhere out of the regular namespace where + // we can get at it in generated code. + intp.quietBind(NamedParam[IMain]("$intp", intp)(tagOfIMain, classTag[IMain])) + + // Auto-run code via some setting. + ( replProps.replAutorunCode.option + flatMap (f => File(f).safeSlurp()) + foreach (intp quietRun _) + ) + // power mode setup + if (isReplPower) enablePowerMode(true) + initializeSpark() + loadInitFiles() + // SI-7418 Now, and only now, can we enable TAB completion. + in.postInit() + } + def loadInitFiles(): Unit = settings match { + case settings: GenericRunnerSettings => + for (f <- settings.loadfiles.value) { + loadCommand(f) + addReplay(s":load $f") + } + for (f <- settings.pastefiles.value) { + pasteCommand(f) + addReplay(s":paste $f") + } + case _ => + } + // wait until after startup to enable noisy settings + def withSuppressedSettings[A](body: => A): A = { + val ss = this.settings + import ss._ + val noisy = List(Xprint, Ytyperdebug) + val noisesome = noisy.exists(!_.isDefault) + val current = (Xprint.value, Ytyperdebug.value) + if (isReplDebug || !noisesome) body + else { + this.settings.Xprint.value = List.empty + this.settings.Ytyperdebug.value = false + try body + finally { + Xprint.value = current._1 + Ytyperdebug.value = current._2 + intp.global.printTypings = current._2 + } + } + } + def startup(): String = withSuppressedSettings { + // let them start typing + val splash = preLoop + + // while we go fire up the REPL + try { + // don't allow ancient sbt to hijack the reader + savingReader { + createInterpreter() + } + intp.initializeSynchronous() + + val field = classOf[ILoop].getDeclaredFields.filter(_.getName.contains("globalFuture")).head + field.setAccessible(true) + field.set(this, Future successful true) + + if (intp.reporter.hasErrors) { + echo("Interpreter encountered errors during initialization!") + null + } else { + loopPostInit() + printWelcome() + splash.start() + + val line = splash.line // what they typed in while they were waiting + if (line == null) { // they ^D + try out print Properties.shellInterruptedString + finally closeInterpreter() + } + line + } + } finally splash.stop() + } + + this.settings = settings + startup() match { + case null => false + case line => + try loop(line) match { + case LineResults.EOF => out print Properties.shellInterruptedString + case _ => + } + catch AbstractOrMissingHandler() + finally closeInterpreter() + true + } + } +} + +object SparkILoop { + + /** + * Creates an interpreter loop with default settings 
and feeds + * the given code to it as input. + */ + def run(code: String, sets: Settings = new Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val input = new BufferedReader(new StringReader(code)) + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) + + if (sets.classpath.isDefault) { + sets.classpath.value = sys.props("java.class.path") + } + repl process sets + } + } + } + def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) +} diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index cdd5cdd841740..4f3df729177fb 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,6 +21,7 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer +import scala.tools.nsc.interpreter.SimpleReader import org.apache.log4j.{Level, LogManager} @@ -84,6 +85,7 @@ class ReplSuite extends SparkFunSuite { settings = new scala.tools.nsc.Settings settings.usejavacp.value = true org.apache.spark.repl.Main.interp = this + in = SimpleReader() } val out = new StringWriter() diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a62f271273465..920f0f6ebf2c8 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -47,6 +47,12 @@ test
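Editor's note on the `SparkILoop.run` helpers above: they are what the REPL test suites drive. A hedged usage sketch (it assumes Spark is on the classpath; the returned string is the console transcript the loop produced):

    // Sketch only: feed a couple of REPL lines and capture whatever the loop prints.
    val transcript: String = SparkILoop.run(List(
      "val xs = 1 to 3",
      "xs.sum"))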
      + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + + io.fabric8 kubernetes-client @@ -77,6 +83,12 @@ + + com.squareup.okhttp3 + okhttp + 3.8.1 + + org.mockito mockito-core @@ -84,9 +96,9 @@ - com.squareup.okhttp3 - okhttp - 3.8.1 + org.jmock + jmock-junit4 + test diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 4086970ffb256..1b582fe53624a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -65,6 +65,7 @@ private[spark] object Config extends Logging { "spark.kubernetes.authenticate.driver" val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = "spark.kubernetes.authenticate.driver.mounted" + val KUBERNETES_AUTH_CLIENT_MODE_PREFIX = "spark.kubernetes.authenticate" val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" @@ -90,7 +91,7 @@ private[spark] object Config extends Logging { ConfigBuilder("spark.kubernetes.submitInDriver") .internal() .booleanConf - .createOptional + .createWithDefault(false) val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") @@ -117,6 +118,41 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("spark") + val KUBERNETES_PYSPARK_PY_FILES = + ConfigBuilder("spark.kubernetes.python.pyFiles") + .doc("The PyFiles that are distributed via client arguments") + .internal() + .stringConf + .createOptional + + val KUBERNETES_PYSPARK_MAIN_APP_RESOURCE = + ConfigBuilder("spark.kubernetes.python.mainAppResource") + .doc("The main app resource for pyspark jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_PYSPARK_APP_ARGS = + ConfigBuilder("spark.kubernetes.python.appArgs") + .doc("The app arguments for PySpark Jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_R_MAIN_APP_RESOURCE = + ConfigBuilder("spark.kubernetes.r.mainAppResource") + .doc("The main app resource for SparkR jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_R_APP_ARGS = + ConfigBuilder("spark.kubernetes.r.appArgs") + .doc("The app arguments for SparkR Jobs") + .internal() + .stringConf + .createOptional + val KUBERNETES_ALLOCATION_BATCH_SIZE = ConfigBuilder("spark.kubernetes.allocation.batch.size") .doc("Number of pods to launch at once in each round of executor allocation.") @@ -154,6 +190,41 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val KUBERNETES_EXECUTOR_API_POLLING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.apiPollingInterval") + .doc("Interval between polls against the Kubernetes API server to inspect the " + + "state of executors.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"API server polling interval must be a" + + " positive time value.") + .createWithDefaultString("30s") + + val KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.eventProcessingInterval") + .doc("Interval between successive inspection of executor events sent from the" + + " Kubernetes API.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"Event processing 
interval must be a positive" + + " time value.") + .createWithDefaultString("1s") + + val MEMORY_OVERHEAD_FACTOR = + ConfigBuilder("spark.kubernetes.memoryOverheadFactor") + .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " + + "which in the case of JVM tasks will default to 0.10 and 0.40 for non-JVM jobs") + .doubleConf + .checkValue(mem_overhead => mem_overhead >= 0 && mem_overhead < 1, + "Ensure that memory overhead is a double between 0 --> 1.0") + .createWithDefault(0.1) + + val PYSPARK_MAJOR_PYTHON_VERSION = + ConfigBuilder("spark.kubernetes.pyspark.pythonVersion") + .doc("This sets the major Python version. Either 2 or 3. (Python2 or Python3)") + .stringConf + .checkValue(pv => List("2", "3").contains(pv), + "Ensure that major Python version is either Python2 or Python3") + .createWithDefault("2") + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -162,10 +233,24 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." + val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." + val KUBERNETES_DRIVER_VOLUMES_PREFIX = "spark.kubernetes.driver.volumes." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." + val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." + val KUBERNETES_EXECUTOR_VOLUMES_PREFIX = "spark.kubernetes.executor.volumes." + + val KUBERNETES_VOLUMES_HOSTPATH_TYPE = "hostPath" + val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" + val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" + val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" + val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" + val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" + val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" + val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." 
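Editor's note on the new prefixes and option keys above: they compose into user-facing properties of the form `<role volume prefix><volumeType>.<volumeName>.<key>`, and secretKeyRef entries map an env var name to `<secret name>:<key>`. A hedged sketch; the volume name (`logs`), secret name, env var, and paths are illustrative values, not defaults:

    import org.apache.spark.SparkConf

    // Sketch only: one hostPath volume for the driver plus one env var sourced
    // from a Kubernetes secret key.
    val conf = new SparkConf()
      .set("spark.kubernetes.driver.volumes.hostPath.logs.mount.path", "/var/log/spark")
      .set("spark.kubernetes.driver.volumes.hostPath.logs.mount.readOnly", "true")
      .set("spark.kubernetes.driver.volumes.hostPath.logs.options.path", "/var/log/spark")
      .set("spark.kubernetes.driver.secretKeyRef.DB_PASSWORD", "db-credentials:password")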
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 8da5f24044aad..8202d874a4626 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -25,9 +25,6 @@ private[spark] object Constants { val SPARK_POD_DRIVER_ROLE = "driver" val SPARK_POD_EXECUTOR_ROLE = "executor" - // Annotations - val SPARK_APP_NAME_ANNOTATION = "spark-app-name" - // Credentials secrets val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = "/mnt/secrets/spark-kubernetes-credentials" @@ -50,17 +47,15 @@ private[spark] object Constants { val DEFAULT_BLOCKMANAGER_PORT = 7079 val DRIVER_PORT_NAME = "driver-rpc-port" val BLOCK_MANAGER_PORT_NAME = "blockmanager" - val EXECUTOR_PORT_NAME = "executor" + val UI_PORT_NAME = "spark-ui" // Environment Variables - val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" val ENV_DRIVER_URL = "SPARK_DRIVER_URL" val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" @@ -71,9 +66,16 @@ private[spark] object Constants { val SPARK_CONF_FILE_NAME = "spark.properties" val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" + // BINDINGS + val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY" + val ENV_PYSPARK_FILES = "PYSPARK_FILES" + val ENV_PYSPARK_ARGS = "PYSPARK_APP_ARGS" + val ENV_PYSPARK_MAJOR_PYTHON_VERSION = "PYSPARK_MAJOR_PYTHON_VERSION" + val ENV_R_PRIMARY = "R_PRIMARY" + val ENV_R_ARGS = "R_APP_ARGS" + // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" - val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN_MIB = 384L } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 77b634ddfabcc..3aa35d419073f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -16,14 +16,17 @@ */ package org.apache.spark.deploy.k8s +import scala.collection.mutable + import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config.ConfigEntry + private[spark] sealed trait KubernetesRoleSpecificConf /* @@ -40,7 +43,7 @@ private[spark] case class KubernetesDriverSpecificConf( */ private[spark] case class KubernetesExecutorSpecificConf( executorId: String, - driverPod: Pod) + driverPod: Option[Pod]) extends KubernetesRoleSpecificConf /** @@ -54,7 +57,10 @@ private[spark] case class KubernetesConf[T <: 
KubernetesRoleSpecificConf]( roleLabels: Map[String, String], roleAnnotations: Map[String, String], roleSecretNamesToMountPaths: Map[String, String], - roleEnvs: Map[String, String]) { + roleSecretEnvNamesToKeyRefs: Map[String, String], + roleEnvs: Map[String, String], + roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]], + sparkFiles: Seq[String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -63,10 +69,17 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( .map(str => str.split(",").toSeq) .getOrElse(Seq.empty[String]) - def sparkFiles(): Seq[String] = sparkConf - .getOption("spark.files") - .map(str => str.split(",").toSeq) - .getOrElse(Seq.empty[String]) + def pyFiles(): Option[String] = sparkConf + .get(KUBERNETES_PYSPARK_PY_FILES) + + def pySparkMainResource(): Option[String] = sparkConf + .get(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE) + + def pySparkPythonVersion(): String = sparkConf + .get(PYSPARK_MAJOR_PYTHON_VERSION) + + def sparkRMainResource(): Option[String] = sparkConf + .get(KUBERNETES_R_MAIN_APP_RESOURCE) def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) @@ -101,17 +114,33 @@ private[spark] object KubernetesConf { appId: String, mainAppResource: Option[MainAppResource], mainClass: String, - appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + appArgs: Array[String], + maybePyFiles: Option[String]): KubernetesConf[KubernetesDriverSpecificConf] = { val sparkConfWithMainAppJar = sparkConf.clone() + val additionalFiles = mutable.ArrayBuffer.empty[String] mainAppResource.foreach { - case JavaMainAppResource(res) => - val previousJars = sparkConf - .getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty) - if (!previousJars.contains(res)) { - sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) - } + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + // The function of this outer match is to account for multiple nonJVM + // bindings that will all have increased default MEMORY_OVERHEAD_FACTOR to 0.4 + case nonJVM: NonJVMResource => + nonJVM match { + case PythonMainAppResource(res) => + additionalFiles += res + maybePyFiles.foreach{maybePyFiles => + additionalFiles.appendAll(maybePyFiles.split(","))} + sparkConfWithMainAppJar.set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, res) + case RMainAppResource(res) => + additionalFiles += res + sparkConfWithMainAppJar.set(KUBERNETES_R_MAIN_APP_RESOURCE, res) + } + sparkConfWithMainAppJar.setIfMissing(MEMORY_OVERHEAD_FACTOR, 0.4) } val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( @@ -129,8 +158,21 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverSecretEnvNamesToKeyRefs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + val driverVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX).map(_.get) + // Also parse executor volumes in order to verify configuration + // before the driver pod is created + 
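Editor's note on the mainAppResource match above: it is what raises the default memory overhead factor for non-JVM bindings. A hedged reduction of that rule (the helper is mine; 0.1 is the MEMORY_OVERHEAD_FACTOR default from Config.scala and 0.4 the value set for Python/R resources, and an explicit spark.kubernetes.memoryOverheadFactor still wins because the code uses setIfMissing):

    // Sketch only: JVM apps keep the 0.1 default, Python/R main resources bump it to 0.4.
    def defaultOverheadFactor(resource: Option[MainAppResource]): Double = resource match {
      case Some(_: NonJVMResource) => 0.4
      case _ => 0.1
    }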
KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) + + val sparkFiles = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) ++ additionalFiles KubernetesConf( sparkConfWithMainAppJar, @@ -140,14 +182,17 @@ private[spark] object KubernetesConf { driverLabels, driverAnnotations, driverSecretNamesToMountPaths, - driverEnvs) + driverSecretEnvNamesToKeyRefs, + driverEnvs, + driverVolumes, + sparkFiles) } def createExecutorConf( sparkConf: SparkConf, executorId: String, appId: String, - driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + driverPod: Option[Pod]): KubernetesConf[KubernetesExecutorSpecificConf] = { val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) require( @@ -167,9 +212,13 @@ private[spark] object KubernetesConf { executorCustomLabels val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + val executorMountSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap + val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) KubernetesConf( sparkConf.clone(), @@ -178,7 +227,10 @@ private[spark] object KubernetesConf { appId, executorLabels, executorAnnotations, - executorSecrets, - executorEnv) + executorMountSecrets, + executorEnvSecrets, + executorEnv, + executorVolumes, + Seq.empty[String]) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index ee629068ad90d..588cd9d40f9a0 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference - import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -52,7 +50,7 @@ private[spark] object KubernetesUtils { } } - private def resolveFileUri(uri: String): String = { + def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { @@ -60,4 +58,6 @@ private[spark] object KubernetesUtils { case _ => uri } } + + def parseMasterUrl(url: String): String = url.substring("k8s://".length) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala new file mode 100644 index 0000000000000..b1762d1efe2ea --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
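Editor's note on createExecutorConf above: since `driverPod` is now an `Option[Pod]`, an executor conf can be built without a driver pod handle. A hedged call sketch (executor id and app id values are illustrative):

    import org.apache.spark.SparkConf

    // Sketch only: with no driver pod, the owner-reference wiring later in this diff
    // simply adds no owner reference instead of dereferencing a missing pod.
    val execConf = KubernetesConf.createExecutorConf(
      sparkConf = new SparkConf(false),
      executorId = "1",
      appId = "spark-application-0001",   // illustrative app id
      driverPod = None)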
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +private[spark] sealed trait KubernetesVolumeSpecificConf + +private[spark] case class KubernetesHostPathVolumeConf( + hostPath: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesPVCVolumeConf( + claimName: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesEmptyDirVolumeConf( + medium: Option[String], + sizeLimit: Option[String]) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( + volumeName: String, + mountPath: String, + mountReadOnly: Boolean, + volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala new file mode 100644 index 0000000000000..713df5fffc3a2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.util.NoSuchElementException + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ + +private[spark] object KubernetesVolumeUtils { + /** + * Extract Spark volume configuration properties with a given name prefix. 
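Editor's note on the KubernetesVolumeSpec case classes above: the parser documented here produces exactly this shape from the prefixed properties shown earlier. A hedged construction sketch (the volume name, mount path, and claim name are illustrative):

    // Sketch only: a read-write, PVC-backed volume spec.
    val spec = KubernetesVolumeSpec(
      volumeName = "checkpoints",
      mountPath = "/checkpoints",
      mountReadOnly = false,
      volumeConf = KubernetesPVCVolumeConf(claimName = "spark-checkpoints-pvc"))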
+ * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing with volume name as key and spec as value + */ + def parseVolumesWithPrefix( + sparkConf: SparkConf, + prefix: String): Iterable[Try[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]] = { + val properties = sparkConf.getAllWithPrefix(prefix).toMap + + getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" + val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + + for { + path <- properties.getTry(pathKey) + volumeConf <- parseVolumeSpecificConf(properties, volumeType, volumeName) + } yield KubernetesVolumeSpec( + volumeName = volumeName, + mountPath = path, + mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), + volumeConf = volumeConf + ) + } + } + + /** + * Get unique pairs of volumeType and volumeName, + * assuming options are formatted in this way: + * `volumeType`.`volumeName`.`property` = `value` + * @param properties flat mapping of property names to values + * @return Set[(volumeType, volumeName)] + */ + private def getVolumeTypesAndNames( + properties: Map[String, String] + ): Set[(String, String)] = { + properties.keys.flatMap { k => + k.split('.').toList match { + case tpe :: name :: _ => Some((tpe, name)) + case _ => None + } + }.toSet + } + + private def parseVolumeSpecificConf( + options: Map[String, String], + volumeType: String, + volumeName: String): Try[KubernetesVolumeSpecificConf] = { + volumeType match { + case KUBERNETES_VOLUMES_HOSTPATH_TYPE => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" + for { + path <- options.getTry(pathKey) + } yield KubernetesHostPathVolumeConf(path) + + case KUBERNETES_VOLUMES_PVC_TYPE => + val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" + for { + claimName <- options.getTry(claimNameKey) + } yield KubernetesPVCVolumeConf(claimName) + + case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => + val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" + val sizeLimitKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY" + Success(KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey))) + + case _ => + Failure(new RuntimeException(s"Kubernetes Volume type `$volumeType` is not supported")) + } + } + + /** + * Convenience wrapper to accumulate key lookup errors + */ + implicit private class MapOps[A, B](m: Map[A, B]) { + def getTry(key: A): Try[B] = { + m + .get(key) + .fold[Try[B]](Failure(new NoSuchElementException(key.toString)))(Success(_)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 07bdccbe0479d..575bc54ffe2bb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -19,14 +19,15 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ import scala.collection.mutable -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import 
io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ -import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.ui.SparkUI private[spark] class BasicDriverFeatureStep( conf: KubernetesConf[KubernetesDriverSpecificConf]) @@ -48,7 +49,8 @@ private[spark] class BasicDriverFeatureStep( private val driverMemoryMiB = conf.get(DRIVER_MEMORY) private val memoryOverheadMiB = conf .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + .getOrElse(math.max((conf.get(MEMORY_OVERHEAD_FACTOR) * driverMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def configurePod(pod: SparkPod): SparkPod = { @@ -71,10 +73,31 @@ private[spark] class BasicDriverFeatureStep( ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) } + val driverPort = conf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) + val driverBlockManagerPort = conf.sparkConf.getInt( + DRIVER_BLOCK_MANAGER_PORT.key, + DEFAULT_BLOCKMANAGER_PORT + ) + val driverUIPort = SparkUI.getUIPort(conf.sparkConf) val driverContainer = new ContainerBuilder(pod.container) .withName(DRIVER_CONTAINER_NAME) .withImage(driverContainerImage) .withImagePullPolicy(conf.imagePullPolicy()) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withContainerPort(driverPort) + .withProtocol("TCP") + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withContainerPort(driverBlockManagerPort) + .withProtocol("TCP") + .endPort() + .addNewPort() + .withName(UI_PORT_NAME) + .withContainerPort(driverUIPort) + .withProtocol("TCP") + .endPort() .addAllToEnv(driverCustomEnvs.asJava) .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) @@ -88,13 +111,6 @@ private[spark] class BasicDriverFeatureStep( .addToRequests("memory", driverMemoryQuantity) .addToLimits("memory", driverMemoryQuantity) .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", conf.roleSpecificConf.mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. 
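Editor's note on the driver memory overhead computation above (the executor step later in this diff uses the same rule): overhead is the larger of factor * memory and the 384 MiB floor from Constants.scala. A worked sketch with an illustrative 2 GiB driver:

    // Sketch only: overhead = max(factor * memory, 384 MiB).
    val driverMemoryMiB  = 2048L                                          // illustrative 2g driver
    val jvmOverhead      = math.max((0.1 * driverMemoryMiB).toInt, 384L)  // = 384
    val pysparkOverhead  = math.max((0.4 * driverMemoryMiB).toInt, 384L)  // = 819
    val podMemoryMiB     = driverMemoryMiB + jvmOverhead                  // = 2432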
- .addToArgs(SparkLauncher.NO_RESOURCE) - .addToArgs(conf.roleSpecificConf.appArgs: _*) .build() val driverPod = new PodBuilder(pod.pod) @@ -109,6 +125,7 @@ private[spark] class BasicDriverFeatureStep( .addToImagePullSecrets(conf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(driverPod, driverContainer) } @@ -122,7 +139,7 @@ private[spark] class BasicDriverFeatureStep( val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( conf.sparkJars()) val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( - conf.sparkFiles()) + conf.sparkFiles) if (resolvedSparkJars.nonEmpty) { additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index d22097587aafe..c37f713c56de1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -18,10 +18,10 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -54,7 +54,8 @@ private[spark] class BasicExecutorFeatureStep( private val memoryOverheadMiB = kubernetesConf .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + .getOrElse(math.max( + (kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) * executorMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB @@ -89,7 +90,9 @@ private[spark] class BasicExecutorFeatureStep( val executorExtraJavaOptionsEnv = kubernetesConf .get(EXECUTOR_JAVA_OPTIONS) .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) + val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, + kubernetesConf.roleSpecificConf.executorId) + val delimitedOpts = Utils.splitCommandString(subsOpts) delimitedOpts.zipWithIndex.map { case (opt, index) => new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() @@ -149,19 +152,20 @@ private[spark] class BasicExecutorFeatureStep( .build() }.getOrElse(executorContainer) val driverPod = kubernetesConf.roleSpecificConf.driverPod + val ownerReference = driverPod.map(pod => + new OwnerReferenceBuilder() + .withController(true) + .withApiVersion(pod.getApiVersion) + .withKind(pod.getKind) + .withName(pod.getMetadata.getName) + .withUid(pod.getMetadata.getUid) + .build()) val executorPod = new PodBuilder(pod.pod) .editOrNewMetadata() .withName(name) .withLabels(kubernetesConf.roleLabels.asJava) .withAnnotations(kubernetesConf.roleAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - 
.withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() + .addToOwnerReferences(ownerReference.toSeq: _*) .endMetadata() .editOrNewSpec() .withHostname(hostname) @@ -170,6 +174,7 @@ private[spark] class BasicExecutorFeatureStep( .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(executorPod, containerWithLimitCores) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala new file mode 100644 index 0000000000000..03ff7d48420ff --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class EnvSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedEnvSecrets = kubernetesConf + .roleSecretEnvNamesToKeyRefs + .map{ case (envName, keyRef) => + // Keyref parts + val keyRefParts = keyRef.split(":") + require(keyRefParts.size == 2, "SecretKeyRef must be in the form name:key.") + val name = keyRefParts(0) + val key = keyRefParts(1) + new EnvVarBuilder() + .withName(envName) + .withNewValueFrom() + .withNewSecretKeyRef() + .withKey(key) + .withName(name) + .endSecretKeyRef() + .endValueFrom() + .build() + } + + val containerWithEnvVars = new ContainerBuilder(pod.container) + .addAllToEnv(addedEnvSecrets.toSeq.asJava) + .build() + SparkPod(pod.pod, containerWithEnvVars) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala new file mode 100644 index 0000000000000..70b307303d149 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.nio.file.Paths +import java.util.UUID + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class LocalDirsFeatureStep( + conf: KubernetesConf[_ <: KubernetesRoleSpecificConf], + defaultLocalDir: String = s"/var/data/spark-${UUID.randomUUID}") + extends KubernetesFeatureConfigStep { + + // Cannot use Utils.getConfiguredLocalDirs because that will default to the Java system + // property - we want to instead default to mounting an emptydir volume that doesn't already + // exist in the image. + // We could make utils.getConfiguredLocalDirs opinionated about Kubernetes, as it is already + // a bit opinionated about YARN and Mesos. + private val resolvedLocalDirs = Option(conf.sparkConf.getenv("SPARK_LOCAL_DIRS")) + .orElse(conf.getOption("spark.local.dir")) + .getOrElse(defaultLocalDir) + .split(",") + + override def configurePod(pod: SparkPod): SparkPod = { + val localDirVolumes = resolvedLocalDirs + .zipWithIndex + .map { case (localDir, index) => + new VolumeBuilder() + .withName(s"spark-local-dir-${index + 1}") + .withNewEmptyDir() + .endEmptyDir() + .build() + } + val localDirVolumeMounts = localDirVolumes + .zip(resolvedLocalDirs) + .map { case (localDirVolume, localDirPath) => + new VolumeMountBuilder() + .withName(localDirVolume.getName) + .withMountPath(localDirPath) + .build() + } + val podWithLocalDirVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(localDirVolumes: _*) + .endSpec() + .build() + val containerWithLocalDirVolumeMounts = new ContainerBuilder(pod.container) + .addNewEnv() + .withName("SPARK_LOCAL_DIRS") + .withValue(resolvedLocalDirs.mkString(",")) + .endEnv() + .addToVolumeMounts(localDirVolumeMounts: _*) + .build() + SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala new file mode 100644 index 0000000000000..bb0e2b3128efd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
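Editor's note on LocalDirsFeatureStep above: a hedged sketch of the resolution order it implements, with the resulting volume naming noted in comments (env and conf values are whatever the user supplies; the helper name is mine):

    import java.util.UUID
    import org.apache.spark.SparkConf

    // Sketch only: SPARK_LOCAL_DIRS wins, then spark.local.dir, then a per-pod default;
    // each resulting path is backed by an emptyDir volume named spark-local-dir-<n>.
    def resolveLocalDirs(conf: SparkConf): Array[String] =
      Option(conf.getenv("SPARK_LOCAL_DIRS"))
        .orElse(conf.getOption("spark.local.dir"))
        .getOrElse(s"/var/data/spark-${UUID.randomUUID}")
        .split(",")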
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.deploy.k8s._ + +private[spark] class MountVolumesFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + + override def configurePod(pod: SparkPod): SparkPod = { + val (volumeMounts, volumes) = constructVolumes(kubernetesConf.roleVolumes).unzip + + val podWithVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(volumes.toSeq: _*) + .endSpec() + .build() + + val containerWithVolumeMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(volumeMounts.toSeq: _*) + .build() + + SparkPod(podWithVolumes, containerWithVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def constructVolumes( + volumeSpecs: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]] + ): Iterable[(VolumeMount, Volume)] = { + volumeSpecs.map { spec => + val volumeMount = new VolumeMountBuilder() + .withMountPath(spec.mountPath) + .withReadOnly(spec.mountReadOnly) + .withName(spec.volumeName) + .build() + + val volumeBuilder = spec.volumeConf match { + case KubernetesHostPathVolumeConf(hostPath) => + new VolumeBuilder() + .withHostPath(new HostPathVolumeSource(hostPath)) + + case KubernetesPVCVolumeConf(claimName) => + new VolumeBuilder() + .withPersistentVolumeClaim( + new PersistentVolumeClaimVolumeSource(claimName, spec.mountReadOnly)) + + case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => + new VolumeBuilder() + .withEmptyDir( + new EmptyDirVolumeSource(medium.getOrElse(""), + new Quantity(sizeLimit.orNull))) + } + + val volume = volumeBuilder.withName(spec.volumeName).build() + + (volumeMount, volume) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala new file mode 100644 index 0000000000000..f52ec9fdc677e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep +import org.apache.spark.launcher.SparkLauncher + +private[spark] class JavaDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val withDriverArgs = new ContainerBuilder(pod.container) + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", kubernetesConf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(kubernetesConf.roleSpecificConf.appArgs: _*) + .build() + SparkPod(pod.pod, withDriverArgs) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala new file mode 100644 index 0000000000000..c20bcac1f8987 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
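Editor's note on JavaDriverFeatureStep above: it now owns the container arguments that were removed from BasicDriverFeatureStep earlier in this diff. A hedged sketch of the argument vector it builds (the main class and app args are illustrative; the constants come from this patch):

    import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH
    import org.apache.spark.launcher.SparkLauncher

    // Sketch only: mode, properties file, main class, a no-resource placeholder, then user args.
    val driverArgs = Seq(
      "driver",
      "--properties-file", SPARK_CONF_PATH,
      "--class", "org.example.MainApp",
      SparkLauncher.NO_RESOURCE) ++ Seq("arg1", "arg2")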
+ */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep + +private[spark] class PythonDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val roleConf = kubernetesConf.roleSpecificConf + require(roleConf.mainAppResource.isDefined, "PySpark Main Resource must be defined") + val maybePythonArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( + pyArgs => + new EnvVarBuilder() + .withName(ENV_PYSPARK_ARGS) + .withValue(pyArgs.mkString(",")) + .build()) + val maybePythonFiles = kubernetesConf.pyFiles().map( + // Dilineation by ":" is to append the PySpark Files to the PYTHONPATH + // of the respective PySpark pod + pyFiles => + new EnvVarBuilder() + .withName(ENV_PYSPARK_FILES) + .withValue(KubernetesUtils.resolveFileUrisAndPath(pyFiles.split(",")) + .mkString(":")) + .build()) + val envSeq = + Seq(new EnvVarBuilder() + .withName(ENV_PYSPARK_PRIMARY) + .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.pySparkMainResource().get)) + .build(), + new EnvVarBuilder() + .withName(ENV_PYSPARK_MAJOR_PYTHON_VERSION) + .withValue(kubernetesConf.pySparkPythonVersion()) + .build()) + val pythonEnvs = envSeq ++ + maybePythonArgs.toSeq ++ + maybePythonFiles.toSeq + + val withPythonPrimaryContainer = new ContainerBuilder(pod.container) + .addAllToEnv(pythonEnvs.asJava) + .addToArgs("driver-py") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", roleConf.mainClass) + .build() + + SparkPod(pod.pod, withPythonPrimaryContainer) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala new file mode 100644 index 0000000000000..b33b86e02ea6f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
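Editor's note on PythonDriverFeatureStep above: for orientation, a hedged sketch of the environment it ends up injecting for an illustrative submission with two extra .py files and app args "100" and "out". The resolved paths assume `local://` URIs and are not taken from the patch; only the env var names and separators come from this diff:

    // Sketch only: what the driver-py container would see for that submission.
    val expectedPysparkEnv = Map(
      "PYSPARK_PRIMARY"              -> "/opt/app/main.py",
      "PYSPARK_FILES"                -> "/opt/app/dep1.py:/opt/app/dep2.py", // ':'-joined to extend PYTHONPATH
      "PYSPARK_APP_ARGS"             -> "100,out",                           // ','-joined appArgs
      "PYSPARK_MAJOR_PYTHON_VERSION" -> "2")                                 // default pythonVersion in this diff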
+ */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep + +private[spark] class RDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val roleConf = kubernetesConf.roleSpecificConf + require(roleConf.mainAppResource.isDefined, "R Main Resource must be defined") + val maybeRArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( + rArgs => + new EnvVarBuilder() + .withName(ENV_R_ARGS) + .withValue(rArgs.mkString(",")) + .build()) + val envSeq = + Seq(new EnvVarBuilder() + .withName(ENV_R_PRIMARY) + .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.sparkRMainResource().get)) + .build()) + val rEnvs = envSeq ++ + maybeRArgs.toSeq + + val withRPrimaryContainer = new ContainerBuilder(pod.container) + .addAllToEnv(rEnvs.asJava) + .addToArgs("driver-r") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", roleConf.mainClass) + .build() + + SparkPod(pod.pod, withRPrimaryContainer) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index a97f5650fb869..986c950ab365a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -39,11 +39,13 @@ import org.apache.spark.util.Utils * @param mainAppResource the main application resource if any * @param mainClass the main class of the application to run * @param driverArgs arguments to the driver + * @param maybePyFiles additional Python files via --py-files */ private[spark] case class ClientArguments( mainAppResource: Option[MainAppResource], mainClass: String, - driverArgs: Array[String]) + driverArgs: Array[String], + maybePyFiles: Option[String]) private[spark] object ClientArguments { @@ -51,10 +53,17 @@ private[spark] object ClientArguments { var mainAppResource: Option[MainAppResource] = None var mainClass: Option[String] = None val driverArgs = mutable.ArrayBuffer.empty[String] + var maybePyFiles : Option[String] = None args.sliding(2, 2).toList.foreach { case Array("--primary-java-resource", primaryJavaResource: String) => mainAppResource = 
Some(JavaMainAppResource(primaryJavaResource)) + case Array("--primary-py-file", primaryPythonResource: String) => + mainAppResource = Some(PythonMainAppResource(primaryPythonResource)) + case Array("--primary-r-file", primaryRFile: String) => + mainAppResource = Some(RMainAppResource(primaryRFile)) + case Array("--other-py-files", pyFiles: String) => + maybePyFiles = Some(pyFiles) case Array("--main-class", clazz: String) => mainClass = Some(clazz) case Array("--arg", arg: String) => @@ -69,7 +78,8 @@ private[spark] object ClientArguments { ClientArguments( mainAppResource, mainClass.get, - driverArgs.toArray) + driverArgs.toArray, + maybePyFiles) } } @@ -206,6 +216,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + sparkConf.set(KUBERNETES_PYSPARK_PY_FILES, clientArguments.maybePyFiles.getOrElse("")) val kubernetesConf = KubernetesConf.createDriverConf( sparkConf, appName, @@ -213,12 +224,13 @@ private[spark] class KubernetesClientApplication extends SparkApplication { kubernetesAppId, clientArguments.mainAppResource, clientArguments.mainClass, - clientArguments.driverArgs) + clientArguments.driverArgs, + clientArguments.maybePyFiles) val builder = new KubernetesDriverBuilder val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. - val master = sparkConf.get("spark.master").substring("k8s://".length) + val master = KubernetesUtils.parseMasterUrl(sparkConf.get("spark.master")) val loggingInterval = if (waitForAppCompletion) Some(sparkConf.get(REPORT_INTERVAL)) else None val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index c7579ed8cb689..8f3f18ffadc3b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep, MountVolumesFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -29,17 +30,58 @@ private[spark] class KubernetesDriverBuilder( new DriverServiceFeatureStep(_), provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_)) { + new MountSecretsFeatureStep(_), + provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => 
EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_), + providePythonStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => PythonDriverFeatureStep) = + new PythonDriverFeatureStep(_), + provideRStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => RDriverFeatureStep) = + new RDriverFeatureStep(_), + provideJavaStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => JavaDriverFeatureStep) = + new JavaDriverFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { val baseFeatures = Seq( provideBasicStep(kubernetesConf), provideCredentialsStep(kubernetesConf), - provideServiceStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) - } else baseFeatures + provideServiceStep(kubernetesConf), + provideLocalDirsStep(kubernetesConf)) + + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val envSecretFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil + + val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { + case JavaMainAppResource(_) => + provideJavaStep(kubernetesConf) + case PythonMainAppResource(_) => + providePythonStep(kubernetesConf) + case RMainAppResource(_) => + provideRStep(kubernetesConf)} + .getOrElse(provideJavaStep(kubernetesConf)) + + val allFeatures = (baseFeatures :+ bindingsStep) ++ + secretFeature ++ envSecretFeature ++ volumesFeature var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala index cca9f4627a1f6..dd5a4549743df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala @@ -18,4 +18,12 @@ package org.apache.spark.deploy.k8s.submit private[spark] sealed trait MainAppResource +private[spark] sealed trait NonJVMResource + private[spark] case class JavaMainAppResource(primaryResource: String) extends MainAppResource + +private[spark] case class PythonMainAppResource(primaryResource: String) + extends MainAppResource with NonJVMResource + +private[spark] case class RMainAppResource(primaryResource: String) + extends MainAppResource with NonJVMResource diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala new file mode 100644 index 0000000000000..83daddf714489 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala 
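As an aside, the dispatch that KubernetesDriverBuilder performs on the sealed MainAppResource hierarchy above can be illustrated with a small, self-contained sketch. The case class names mirror the ones added in this patch, but BindingDispatchExample, bindingFor, and the returned step names are hypothetical stand-ins used only for illustration; this code is not part of the patch.

// Hypothetical illustration (not part of the patch) of how the driver builder picks
// exactly one language-binding feature step from the primary application resource,
// defaulting to the JVM binding when no primary resource is given.
sealed trait MainAppResource
case class JavaMainAppResource(primaryResource: String) extends MainAppResource
case class PythonMainAppResource(primaryResource: String) extends MainAppResource
case class RMainAppResource(primaryResource: String) extends MainAppResource

object BindingDispatchExample {
  def bindingFor(resource: Option[MainAppResource]): String = resource match {
    case Some(PythonMainAppResource(_)) => "PythonDriverFeatureStep" // adds "driver-py" args
    case Some(RMainAppResource(_)) => "RDriverFeatureStep" // adds "driver-r" args
    case Some(JavaMainAppResource(_)) | None => "JavaDriverFeatureStep" // JVM default
  }

  def main(args: Array[String]): Unit = {
    println(bindingFor(Some(PythonMainAppResource("local:///opt/spark/app/main.py")))) // PythonDriverFeatureStep
    println(bindingFor(None)) // JavaDriverFeatureStep
  }
}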
@@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +sealed trait ExecutorPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends ExecutorPodState + +case class PodPending(pod: Pod) extends ExecutorPodState + +sealed trait FinalPodState extends ExecutorPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends ExecutorPodState diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala new file mode 100644 index 0000000000000..77bb9c3fcc9f4 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, Utils} + +private[spark] class ExecutorPodsAllocator( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + clock: Clock) extends Logging { + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + + private val totalExpectedExecutors = new AtomicInteger(0) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + + private val namespace = conf.get(KUBERNETES_NAMESPACE) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + + private val driverPod = kubernetesDriverPodName + .map(name => Option(kubernetesClient.pods() + .withName(name) + .get()) + .getOrElse(throw new SparkException( + s"No pod was found named $kubernetesDriverPodName in the cluster in the " + + s"namespace $namespace (this was supposed to be the driver pod.)."))) + + // Executor IDs that have been requested from Kubernetes but have not been detected in any + // snapshot yet. Mapped to the timestamp when they were created. + private val newlyCreatedExecutors = mutable.Map.empty[Long, Long] + + def start(applicationId: String): Unit = { + snapshotsStore.addSubscriber(podAllocationDelay) { + onNewSnapshots(applicationId, _) + } + } + + def setTotalExpectedExecutors(total: Int): Unit = totalExpectedExecutors.set(total) + + private def onNewSnapshots(applicationId: String, snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + newlyCreatedExecutors --= snapshots.flatMap(_.executorPods.keys) + // For all executors we've created against the API but have not seen in a snapshot + // yet - check the current time. If the current time has exceeded some threshold, + // assume that the pod was either never created (the API server never properly + // handled the creation request), or the API server created the pod but we missed + // both the creation and deletion events. In either case, delete the missing pod + // if possible, and mark such a pod to be rescheduled below. + newlyCreatedExecutors.foreach { case (execId, timeCreated) => + val currentTime = clock.getTimeMillis() + if (currentTime - timeCreated > podCreationTimeout) { + logWarning(s"Executor with id $execId was not detected in the Kubernetes" + + s" cluster after $podCreationTimeout milliseconds despite the fact that a" + + " previous allocation attempt tried to create it. 
The executor may have been" + + " deleted but the application missed the deletion event.") + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) + .delete() + } + newlyCreatedExecutors -= execId + } else { + logDebug(s"Executor with id $execId was not found in the Kubernetes cluster since it" + + s" was created ${currentTime - timeCreated} milliseconds ago.") + } + } + + if (snapshots.nonEmpty) { + // Only need to examine the cluster as of the latest snapshot, the "current" state, to see if + // we need to allocate more executors or not. + val latestSnapshot = snapshots.last + val currentRunningExecutors = latestSnapshot.executorPods.values.count { + case PodRunning(_) => true + case _ => false + } + val currentPendingExecutors = latestSnapshot.executorPods.values.count { + case PodPending(_) => true + case _ => false + } + val currentTotalExpectedExecutors = totalExpectedExecutors.get + logDebug(s"Currently have $currentRunningExecutors running executors and" + + s" $currentPendingExecutors pending executors. $newlyCreatedExecutors executors" + + s" have been requested but are pending appearance in the cluster.") + if (newlyCreatedExecutors.isEmpty + && currentPendingExecutors == 0 + && currentRunningExecutors < currentTotalExpectedExecutors) { + val numExecutorsToAllocate = math.min( + currentTotalExpectedExecutors - currentRunningExecutors, podAllocationSize) + logInfo(s"Going to request $numExecutorsToAllocate executors from Kubernetes.") + for ( _ <- 0 until numExecutorsToAllocate) { + val newExecutorId = EXECUTOR_ID_COUNTER.incrementAndGet() + val executorConf = KubernetesConf.createExecutorConf( + conf, + newExecutorId.toString, + applicationId, + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + kubernetesClient.pods().create(podWithAttachedContainer) + newlyCreatedExecutors(newExecutorId) = clock.getTimeMillis() + logDebug(s"Requested executor with id $newExecutorId from Kubernetes.") + } + } else if (currentRunningExecutors >= currentTotalExpectedExecutors) { + // TODO handle edge cases if we end up with more running executors than expected. + logDebug("Current number of running executors is equal to the number of requested" + + " executors. Not scaling up further.") + } else if (newlyCreatedExecutors.nonEmpty || currentPendingExecutors != 0) { + logDebug(s"Still waiting for ${newlyCreatedExecutors.size + currentPendingExecutors}" + + s" executors to begin running before requesting for more executors. # of executors in" + + s" pending status in the cluster: $currentPendingExecutors. # of executors that we have" + + s" created but we have not observed as being present in the cluster yet:" + + s" ${newlyCreatedExecutors.size}.") + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala new file mode 100644 index 0000000000000..b28d93990313e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.Cache +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsLifecycleManager( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + // Use a best-effort to track which executors have been removed already. It's not generally + // job-breaking if we remove executors more than once but it's ideal if we make an attempt + // to avoid doing so. Expire cache entries so that this data structure doesn't grow beyond + // bounds. + removedExecutorsCache: Cache[java.lang.Long, java.lang.Long]) extends Logging { + + import ExecutorPodsLifecycleManager._ + + private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) + + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { + snapshotsStore.addSubscriber(eventProcessingInterval) { + onNewSnapshots(schedulerBackend, _) + } + } + + private def onNewSnapshots( + schedulerBackend: KubernetesClusterSchedulerBackend, + snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + val execIdsRemovedInThisRound = mutable.HashSet.empty[Long] + snapshots.foreach { snapshot => + snapshot.executorPods.foreach { case (execId, state) => + state match { + case deleted@PodDeleted(_) => + logDebug(s"Snapshot reported deleted executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + removeExecutorFromSpark(schedulerBackend, deleted, execId) + execIdsRemovedInThisRound += execId + case failed@PodFailed(_) => + logDebug(s"Snapshot reported failed executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + onFinalNonDeletedState(failed, execId, schedulerBackend, execIdsRemovedInThisRound) + case succeeded@PodSucceeded(_) => + logDebug(s"Snapshot reported succeeded executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}. Note that succeeded executors are" + + s" unusual unless Spark specifically informed the executor to exit.") + onFinalNonDeletedState(succeeded, execId, schedulerBackend, execIdsRemovedInThisRound) + case _ => + } + } + } + + // Reconcile the case where Spark claims to know about an executor but the corresponding pod + // is missing from the cluster. This would occur if we miss a deletion event and the pod + // transitions immediately from running to absent.
We only need to check against the latest + // snapshot for this, and we don't do this for executors in the deleted executors cache or + // that we just removed in this round. + if (snapshots.nonEmpty) { + val latestSnapshot = snapshots.last + (schedulerBackend.getExecutorIds().map(_.toLong).toSet + -- latestSnapshot.executorPods.keySet + -- execIdsRemovedInThisRound).foreach { missingExecutorId => + if (removedExecutorsCache.getIfPresent(missingExecutorId) == null) { + val exitReasonMessage = s"The executor with ID $missingExecutorId was not found in the" + + s" cluster but we didn't get a reason why. Marking the executor as failed. The" + + s" executor may have been deleted but the driver missed the deletion event." + logDebug(exitReasonMessage) + val exitReason = ExecutorExited( + UNKNOWN_EXIT_CODE, + exitCausedByApp = false, + exitReasonMessage) + schedulerBackend.doRemoveExecutor(missingExecutorId.toString, exitReason) + execIdsRemovedInThisRound += missingExecutorId + } + } + } + logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + + s" from Spark that were either found to be deleted or non-existent in the cluster.") + } + + private def onFinalNonDeletedState( + podState: FinalPodState, + execId: Long, + schedulerBackend: KubernetesClusterSchedulerBackend, + execIdsRemovedInRound: mutable.Set[Long]): Unit = { + removeExecutorFromK8s(podState.pod) + removeExecutorFromSpark(schedulerBackend, podState, execId) + execIdsRemovedInRound += execId + } + + private def removeExecutorFromK8s(updatedPod: Pod): Unit = { + // If deletion failed on a previous try, we can try again if resync informs us the pod + // is still around. + // Delete as best attempt - duplicate deletes will throw an exception but the end state + // of getting rid of the pod is what matters. + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withName(updatedPod.getMetadata.getName) + .delete() + } + } + + private def removeExecutorFromSpark( + schedulerBackend: KubernetesClusterSchedulerBackend, + podState: FinalPodState, + execId: Long): Unit = { + if (removedExecutorsCache.getIfPresent(execId) == null) { + removedExecutorsCache.put(execId, execId) + val exitReason = findExitReason(podState, execId) + schedulerBackend.doRemoveExecutor(execId.toString, exitReason) + } + } + + private def findExitReason(podState: FinalPodState, execId: Long): ExecutorExited = { + val exitCode = findExitCode(podState) + val (exitCausedByApp, exitMessage) = podState match { + case PodDeleted(_) => + (false, s"The executor with id $execId was deleted by a user or the framework.") + case _ => + val msg = exitReasonMessage(podState, execId, exitCode) + (true, msg) + } + ExecutorExited(exitCode, exitCausedByApp, exitMessage) + } + + private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { + val pod = podState.pod + s""" + |The executor with id $execId exited with exit code $exitCode. 
+ |The API gave the following brief reason: ${pod.getStatus.getReason} + |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def findExitCode(podState: FinalPodState): Int = { + podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => + containerStatus.getState.getTerminated != null + }.map { terminatedContainer => + terminatedContainer.getState.getTerminated.getExitCode.toInt + }.getOrElse(UNKNOWN_EXIT_CODE) + } +} + +private object ExecutorPodsLifecycleManager { + val UNKNOWN_EXIT_CODE = -1 +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala new file mode 100644 index 0000000000000..e77e604d00e0f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{Future, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils + +private[spark] class ExecutorPodsPollingSnapshotSource( + conf: SparkConf, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + pollingExecutor: ScheduledExecutorService) extends Logging { + + private val pollingInterval = conf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + private var pollingFuture: Future[_] = _ + + def start(applicationId: String): Unit = { + require(pollingFuture == null, "Cannot start polling more than once.") + logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") + pollingFuture = pollingExecutor.scheduleWithFixedDelay( + new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) + } + + def stop(): Unit = { + if (pollingFuture != null) { + pollingFuture.cancel(true) + pollingFuture = null + } + ThreadUtils.shutdown(pollingExecutor) + } + + private class PollRunnable(applicationId: String) extends Runnable { + override def run(): Unit = { + logDebug(s"Resynchronizing full executor pod state from Kubernetes.") + snapshotsStore.replaceSnapshot(kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .list() + .getItems + .asScala) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala new file mode 100644 index 0000000000000..26be918043412 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging + +/** + * An immutable view of the current executor pods that are running in the cluster. 
+ */ +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { + + import ExecutorPodsSnapshot._ + + def withUpdate(updatedPod: Pod): ExecutorPodsSnapshot = { + val newExecutorPods = executorPods ++ toStatesByExecutorId(Seq(updatedPod)) + new ExecutorPodsSnapshot(newExecutorPods) + } +} + +object ExecutorPodsSnapshot extends Logging { + + def apply(executorPods: Seq[Pod]): ExecutorPodsSnapshot = { + ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) + } + + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + executorPods.map { pod => + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + }.toMap + } + + private def toState(pod: Pod): ExecutorPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..dd264332cf9e8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +private[spark] trait ExecutorPodsSnapshotsStore { + + def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) + + def stop(): Unit + + def updatePod(updatedPod: Pod): Unit + + def replaceSnapshot(newSnapshot: Seq[Pod]): Unit +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala new file mode 100644 index 0000000000000..5583b4617eeb2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent._ + +import io.fabric8.kubernetes.api.model.Pod +import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Controls the propagation of the Spark application's executor pods state to subscribers that + * react to that state. + *
      + * Roughly follows a producer-consumer model. Producers report states of executor pods, and these + * states are then published to consumers that can perform any actions in response to these states. + *
      + * Producers push updates in one of two ways. An incremental update sent by updatePod() represents + * a known new state of a single executor pod. A full sync sent by replaceSnapshot() indicates that + * the passed pods are all of the most up to date states of all executor pods for the application. + * The combination of the states of all executor pods for the application is collectively known as + * a snapshot. The store keeps track of the most up to date snapshot, and applies updates to that + * most recent snapshot - either by incrementally updating the snapshot with a single new pod state, + * or by replacing the snapshot entirely on a full sync. + *
      + * Consumers, or subscribers, register that they want to be informed about all snapshots of the + * executor pods. Every time the store replaces its most up to date snapshot from either an + * incremental update or a full sync, the most recent snapshot after the update is posted to the + * subscriber's buffer. Subscribers receive blocks of snapshots produced by the producers in + * time-windowed chunks. Each subscriber can choose to receive their snapshot chunks at different + * time intervals. + */ +private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: ScheduledExecutorService) + extends ExecutorPodsSnapshotsStore { + + private val SNAPSHOT_LOCK = new Object() + + private val subscribers = mutable.Buffer.empty[SnapshotsSubscriber] + private val pollingTasks = mutable.Buffer.empty[Future[_]] + + @GuardedBy("SNAPSHOT_LOCK") + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber( + processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + val newSubscriber = SnapshotsSubscriber( + new LinkedBlockingQueue[ExecutorPodsSnapshot](), onNewSnapshots) + SNAPSHOT_LOCK.synchronized { + newSubscriber.snapshotsBuffer.add(currentSnapshot) + } + subscribers += newSubscriber + pollingTasks += subscribersExecutor.scheduleWithFixedDelay( + toRunnable(() => callSubscriber(newSubscriber)), + 0L, + processBatchIntervalMillis, + TimeUnit.MILLISECONDS) + } + + override def stop(): Unit = { + pollingTasks.foreach(_.cancel(true)) + ThreadUtils.shutdown(subscribersExecutor) + } + + override def updatePod(updatedPod: Pod): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + addCurrentSnapshotToSubscribers() + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + addCurrentSnapshotToSubscribers() + } + + private def addCurrentSnapshotToSubscribers(): Unit = { + subscribers.foreach { subscriber => + subscriber.snapshotsBuffer.add(currentSnapshot) + } + } + + private def callSubscriber(subscriber: SnapshotsSubscriber): Unit = { + Utils.tryLogNonFatalError { + val currentSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot].asJava + subscriber.snapshotsBuffer.drainTo(currentSnapshots) + subscriber.onNewSnapshots(currentSnapshots.asScala) + } + } + + private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable { + override def run(): Unit = runnable() + } + + private case class SnapshotsSubscriber( + snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot], + onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala new file mode 100644 index 0000000000000..a6749a644e00c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsWatchSnapshotSource( + snapshotsStore: ExecutorPodsSnapshotsStore, + kubernetesClient: KubernetesClient) extends Logging { + + private var watchConnection: Closeable = _ + + def start(applicationId: String): Unit = { + require(watchConnection == null, "Cannot start the watcher twice.") + logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + + s" $SPARK_ROLE_LABEL=$SPARK_POD_EXECUTOR_ROLE.") + watchConnection = kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .watch(new ExecutorPodsWatcher()) + } + + def stop(): Unit = { + if (watchConnection != null) { + Utils.tryLogNonFatalError { + watchConnection.close() + } + watchConnection = null + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + logDebug(s"Received executor pod update for pod named $podName, action $action") + snapshotsStore.updatePod(pod) + } + + override def onClose(e: KubernetesClientException): Unit = { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 0ea80dfbc0d97..9999c62c878df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -17,28 +17,24 @@ package org.apache.spark.scheduler.cluster.k8s import java.io.File +import java.util.concurrent.TimeUnit +import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SystemClock, ThreadUtils} private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") 
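To make the watch-and-snapshot plumbing above easier to follow, here is a hedged, dependency-free sketch of the producer/consumer pattern that ExecutorPodsWatchSnapshotSource and ExecutorPodsSnapshotsStoreImpl implement together: a producer pushes pod updates into a buffered snapshot store, and a subscriber periodically drains whatever snapshots accumulated since its last run. Pod states are simplified to strings, and the names ToySnapshot and SnapshotsStoreSketch are hypothetical; this code is not part of the patch.

import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}
import scala.collection.JavaConverters._

// Toy snapshot: executor id -> simplified pod state ("pending", "running", ...).
final case class ToySnapshot(executorPods: Map[Long, String])

object SnapshotsStoreSketch {
  private val lock = new Object
  private var current = ToySnapshot(Map.empty)
  private val subscriberBuffer = new LinkedBlockingQueue[ToySnapshot]()

  // Producer side, analogous to updatePod(): apply one incremental update and publish
  // the resulting snapshot to the subscriber's buffer (only one subscriber here).
  def updatePod(execId: Long, state: String): Unit = lock.synchronized {
    current = ToySnapshot(current.executorPods + (execId -> state))
    subscriberBuffer.add(current)
  }

  def main(args: Array[String]): Unit = {
    val executor = Executors.newSingleThreadScheduledExecutor()
    // Consumer side, analogous to addSubscriber(): drain the batch accumulated in this interval.
    executor.scheduleWithFixedDelay(new Runnable {
      override def run(): Unit = {
        val batch = new java.util.ArrayList[ToySnapshot]()
        subscriberBuffer.drainTo(batch)
        batch.asScala.foreach(s => println(s"subscriber saw: ${s.executorPods}"))
      }
    }, 0L, 100L, TimeUnit.MILLISECONDS)

    updatePod(1L, "pending")
    updatePod(1L, "running")
    updatePod(2L, "pending")
    Thread.sleep(500) // give the subscriber a chance to drain
    executor.shutdownNow()
  }
}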
override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { - if (masterURL.startsWith("k8s") && - sc.deployMode == "client" && - !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) { - throw new SparkException("Client mode is currently not supported for Kubernetes.") - } - new TaskSchedulerImpl(sc) } @@ -46,27 +42,72 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val wasSparkSubmittedInClusterMode = sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) + val (authConfPrefix, + apiServerUri, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { + require(sc.conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, + "If the application is deployed using spark-submit in cluster mode, the driver pod name " + + "must be provided.") + (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + KUBERNETES_MASTER_INTERNAL_URL, + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + } else { + (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, + KubernetesUtils.parseMasterUrl(masterURL), + None, + None) + } + val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( - KUBERNETES_MASTER_INTERNAL_URL, + apiServerUri, Some(sc.conf.get(KUBERNETES_NAMESPACE)), - KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + authConfPrefix, sc.conf, - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + defaultServiceAccountToken, + defaultServiceAccountCaCrt) - val allocatorExecutor = ThreadUtils - .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") + + val subscribersExecutor = ThreadUtils + .newDaemonThreadPoolScheduledExecutor( + "kubernetes-executor-snapshots-subscribers", 2) + val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) + val removedExecutorsCache = CacheBuilder.newBuilder() + .expireAfterWrite(3, TimeUnit.MINUTES) + .build[java.lang.Long, java.lang.Long]() + val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( + sc.conf, + new KubernetesExecutorBuilder(), + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + + val executorPodsAllocator = new ExecutorPodsAllocator( + sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) + + val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource( + snapshotsStore, + kubernetesClient) + + val eventsPollingExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kubernetes-executor-pod-polling-sync") + val podsPollingEventSource = new ExecutorPodsPollingSnapshotSource( + sc.conf, kubernetesClient, snapshotsStore, eventsPollingExecutor) + new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - new KubernetesExecutorBuilder, kubernetesClient, - allocatorExecutor, - requestExecutorsService) + requestExecutorsService, + snapshotsStore, + executorPodsAllocator, + executorPodsLifecycleEventHandler, + podsWatchEventSource, + podsPollingEventSource) } override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index d86664c81071b..fa6dc2c479bbf 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -16,60 +16,32 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.Closeable -import java.net.InetAddress -import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import scala.collection.JavaConverters._ -import scala.collection.mutable +import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesConf -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} -import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, - allocatorExecutor: ScheduledExecutorService, - requestExecutorsService: ExecutorService) + requestExecutorsService: ExecutorService, + snapshotsStore: ExecutorPodsSnapshotsStore, + podAllocator: ExecutorPodsAllocator, + lifecycleEventHandler: ExecutorPodsLifecycleManager, + watchEvents: ExecutorPodsWatchSnapshotSource, + pollEvents: ExecutorPodsPollingSnapshotSource) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - import KubernetesClusterSchedulerBackend._ - - private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) - private val RUNNING_EXECUTOR_PODS_LOCK = new Object - @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") - private val runningExecutorsToPods = new mutable.HashMap[String, Pod] - private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() - private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() - private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() - - private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) - - private val kubernetesDriverPodName = conf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( requestExecutorsService) - private val driverPod = kubernetesClient.pods() - .inNamespace(kubernetesNamespace) - 
.withName(kubernetesDriverPodName) - .get() - protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { 0.8 @@ -77,372 +49,93 @@ private[spark] class KubernetesClusterSchedulerBackend( super.minRegisteredRatio } - private val executorWatchResource = new AtomicReference[Closeable] - private val totalExpectedExecutors = new AtomicInteger(0) - - private val driverUrl = RpcEndpointAddress( - conf.get("spark.driver.host"), - conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString - private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) - private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) - - private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) - - private val executorLostReasonCheckMaxAttempts = conf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - private val allocatorRunnable = new Runnable { - - // Maintains a map of executor id to count of checks performed to learn the loss reason - // for an executor. - private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] - - override def run(): Unit = { - handleDisconnectedExecutors() - - val executorsToAllocate = mutable.Map[String, Pod]() - val currentTotalRegisteredExecutors = totalRegisteredExecutors.get - val currentTotalExpectedExecutors = totalExpectedExecutors.get - val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { - logDebug("Waiting for pending executors before scaling") - } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { - logDebug("Maximum allowed executor limit reached. Not scaling up further.") - } else { - for (_ <- 0 until math.min( - currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { - val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorConf = KubernetesConf.createExecutorConf( - conf, - executorId, - applicationId(), - driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf) - val podWithAttachedContainer = new PodBuilder(executorPod.pod) - .editOrNewSpec() - .addToContainers(executorPod.container) - .endSpec() - .build() - - executorsToAllocate(executorId) = podWithAttachedContainer - logInfo( - s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") - } - } - } - - val allocatedExecutors = executorsToAllocate.mapValues { pod => - Utils.tryLog { - kubernetesClient.pods().create(pod) - } - } - - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - allocatedExecutors.map { - case (executorId, attemptedAllocatedExecutor) => - attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => - runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) - } - } - } - } - - def handleDisconnectedExecutors(): Unit = { - // For each disconnected executor, synchronize with the loss reasons that may have been found - // by the executor pod watcher. If the loss reason was discovered by the watcher, - // inform the parent class with removeExecutor. 
- disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { - case (executorId, executorPod) => - val knownExitReason = Option(podsWithKnownExitReasons.remove( - executorPod.getMetadata.getName)) - knownExitReason.fold { - removeExecutorOrIncrementLossReasonCheckCount(executorId) - } { executorExited => - logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) - removeExecutor(executorId, executorExited) - // We don't delete the pod running the executor that has an exit condition caused by - // the application from the Kubernetes API server. This allows users to debug later on - // through commands such as "kubectl logs " and - // "kubectl describe pod ". Note that exited containers have terminated and - // therefore won't take CPU and memory resources. - // Otherwise, the executor pod is marked to be deleted from the API server. - if (executorExited.exitCausedByApp) { - logInfo(s"Executor $executorId exited because of the application.") - deleteExecutorFromDataStructures(executorId) - } else { - logInfo(s"Executor $executorId failed because of a framework error.") - deleteExecutorFromClusterAndDataStructures(executorId) - } - } - } - } - - def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { - val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) - if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { - removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) - deleteExecutorFromClusterAndDataStructures(executorId) - } else { - executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) - } - } - - def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { - deleteExecutorFromDataStructures(executorId).foreach { pod => - kubernetesClient.pods().delete(pod) - } - } - - def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { - disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) - executorReasonCheckAttemptCounts -= executorId - podsWithKnownExitReasons.remove(executorId) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.remove(executorId).orElse { - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler + private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + removeExecutor(executorId, reason) } override def start(): Unit = { super.start() - executorWatchResource.set( - kubernetesClient - .pods() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .watch(new ExecutorPodsWatcher())) - - allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS) - if (!Utils.isDynamicAllocationEnabled(conf)) { - doRequestTotalExecutors(initialExecutors) + podAllocator.setTotalExpectedExecutors(initialExecutors) } + lifecycleEventHandler.start(this) + podAllocator.start(applicationId()) + watchEvents.start(applicationId()) + pollEvents.start(applicationId()) } override def stop(): Unit = { - // stop allocation of new resources and caches. 
- allocatorExecutor.shutdown() - allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) - - // send stop message to executors so they shut down cleanly super.stop() - try { - val resource = executorWatchResource.getAndSet(null) - if (resource != null) { - resource.close() - } - } catch { - case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + Utils.tryLogNonFatalError { + snapshotsStore.stop() } - // then delete the executor pods Utils.tryLogNonFatalError { - deleteExecutorPodsOnStop() - executorPodsByIPs.clear() + watchEvents.stop() } + Utils.tryLogNonFatalError { - logInfo("Closing kubernetes client") - kubernetesClient.close() + pollEvents.stop() } - } - /** - * @return A map of K8s cluster nodes to the number of tasks that could benefit from data - * locality if an executor launches on the cluster node. - */ - private def getNodesWithLocalTaskCounts() : Map[String, Int] = { - val nodeToLocalTaskCount = synchronized { - mutable.Map[String, Int]() ++ hostToLocalTaskCount + Utils.tryLogNonFatalError { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .delete() } - for (pod <- executorPodsByIPs.values().asScala) { - // Remove cluster nodes that are running our executors already. - // TODO: This prefers spreading out executors across nodes. In case users want - // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut - // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html - nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || - nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || - nodeToLocalTaskCount.remove( - InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + Utils.tryLogNonFatalError { + ThreadUtils.shutdown(requestExecutorsService) + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() } - nodeToLocalTaskCount.toMap[String, Int] } override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { - totalExpectedExecutors.set(requestedTotal) + // TODO when we support dynamic allocation, the pod allocator should be told to process the + // current snapshot in order to decrease/increase the number of executors accordingly. 
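
The new stop() above and doKillExecutors further below no longer consult a locally maintained executor-to-pod map; both rely on label selectors alone and leave the watch/poll events to update Spark-side state afterwards. The same fabric8 pattern as a standalone sketch (the client construction and the literal label values are illustrative; in the patch they come from the Constants object):

import io.fabric8.kubernetes.client.DefaultKubernetesClient

val client = new DefaultKubernetesClient()   // auto-configured from the environment
val appId = "spark-application-001"          // hypothetical application id

// Delete every executor pod owned by this application (the stop() case).
client.pods()
  .withLabel("spark-app-selector", appId)    // SPARK_APP_ID_LABEL
  .withLabel("spark-role", "executor")       // SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE
  .delete()

// Delete only specific executors by id (the doKillExecutors case).
client.pods()
  .withLabel("spark-app-selector", appId)
  .withLabel("spark-role", "executor")
  .withLabelIn("spark-exec-id", "1", "2")    // SPARK_EXECUTOR_ID_LABEL
  .delete()
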
+ podAllocator.setTotalExpectedExecutors(requestedTotal) true } - override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - executorIds.flatMap { executorId => - runningExecutorsToPods.remove(executorId) match { - case Some(pod) => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - Some(pod) - - case None => - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - - kubernetesClient.pods().delete(podsToDelete: _*) - true + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio } - private def deleteExecutorPodsOnStop(): Unit = { - val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) - runningExecutorsToPods.clear() - runningExecutorPodsCopy - } - kubernetesClient.pods().delete(executorPodsToDelete: _*) + override def getExecutorIds(): Seq[String] = synchronized { + super.getExecutorIds() } - private class ExecutorPodsWatcher extends Watcher[Pod] { - - private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 - - override def eventReceived(action: Action, pod: Pod): Unit = { - val podName = pod.getMetadata.getName - val podIP = pod.getStatus.getPodIP - - action match { - case Action.MODIFIED if (pod.getStatus.getPhase == "Running" - && pod.getMetadata.getDeletionTimestamp == null) => - val clusterNodeName = pod.getSpec.getNodeName - logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") - executorPodsByIPs.put(podIP, pod) - - case Action.DELETED | Action.ERROR => - val executorId = getExecutorId(pod) - logDebug(s"Executor pod $podName at IP $podIP was at $action.") - if (podIP != null) { - executorPodsByIPs.remove(podIP) - } - - val executorExitReason = if (action == Action.ERROR) { - logWarning(s"Received error event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnError(pod) - } else if (action == Action.DELETED) { - logWarning(s"Received delete event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnDelete(pod) - } else { - throw new IllegalStateException( - s"Unknown action that should only be DELETED or ERROR: $action") - } - podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) - - if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { - log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + - s"watch received an event of type $action for this executor. The executor may " + - "have failed to start in the first place and never registered with the driver.") - } - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - - case _ => logDebug(s"Received event of executor pod $podName: " + action) - } - } - - override def onClose(cause: KubernetesClientException): Unit = { - logDebug("Executor pod watch closed.", cause) - } - - private def getExecutorExitStatus(pod: Pod): Int = { - val containerStatuses = pod.getStatus.getContainerStatuses - if (!containerStatuses.isEmpty) { - // we assume the first container represents the pod status. This assumption may not hold - // true in the future. Revisit this if side-car containers start running inside executor - // pods. 
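
For reference, the deleted watcher code in this hunk derived an exit reason from the first container status of the pod, falling back to a -1 sentinel when no terminated state was present. Restated as a small self-contained helper (the names are mine; the logic mirrors the removed getExecutorExitStatus methods):

import io.fabric8.kubernetes.api.model.{ContainerStatus, Pod}

object ExitStatus {
  private val UNKNOWN_EXIT_CODE = -1

  // First container's terminated exit code, if any; otherwise the sentinel.
  def of(pod: Pod): Int = {
    val statuses = pod.getStatus.getContainerStatuses
    if (statuses.isEmpty) UNKNOWN_EXIT_CODE else of(statuses.get(0))
  }

  def of(status: ContainerStatus): Int =
    Option(status.getState)
      .flatMap(state => Option(state.getTerminated))
      .map(_.getExitCode.intValue())
      .getOrElse(UNKNOWN_EXIT_CODE)
}
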
- getExecutorExitStatus(containerStatuses.get(0)) - } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS - } - - private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { - Option(containerStatus.getState).map { containerState => - Option(containerState.getTerminated).map { containerStateTerminated => - containerStateTerminated.getExitCode.intValue() - }.getOrElse(UNKNOWN_EXIT_CODE) - }.getOrElse(UNKNOWN_EXIT_CODE) - } - - private def isPodAlreadyReleased(pod: Pod): Boolean = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - !runningExecutorsToPods.contains(executorId) - } - } - - private def executorExitReasonOnError(pod: Pod): ExecutorExited = { - val containerExitStatus = getExecutorExitStatus(pod) - // container was probably actively killed by the driver. - if (isPodAlreadyReleased(pod)) { - ExecutorExited(containerExitStatus, exitCausedByApp = false, - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + - "request.") - } else { - val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + - s"exited with exit status code $containerExitStatus." - ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) - } - } - - private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { - val exitMessage = if (isPodAlreadyReleased(pod)) { - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." - } else { - s"Pod ${pod.getMetadata.getName} deleted or lost." - } - ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) - } - - private def getExecutorId(pod: Pod): String = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - require(executorId != null, "Unexpected pod metadata; expected all executor pods " + - s"to have label $SPARK_EXECUTOR_ID_LABEL.") - executorId - } + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) + .delete() + // Don't do anything else - let event handling from the Kubernetes API do the Spark changes } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { new KubernetesDriverEndpoint(rpcEnv, properties) } - private class KubernetesDriverEndpoint( - rpcEnv: RpcEnv, - sparkProperties: Seq[(String, String)]) + private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { - addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.get(executorId).foreach { pod => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - } - } - } - } + // Don't do anything besides disabling the executor - allow the Kubernetes API events to + // drive the rest of the lifecycle decisions + // TODO what if we disconnect from a networking issue? Probably want to mark the executor + // to be deleted eventually. 
+ addressToExecutorId.get(rpcAddress).foreach(disableExecutor) } } -} -private object KubernetesClusterSchedulerBackend { - private val UNKNOWN_EXIT_CODE = -1 } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 22568fe7ea3be..364b6fb367722 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,21 +17,42 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + provideBasicStep: (KubernetesConf [KubernetesExecutorSpecificConf]) + => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), - provideSecretsStep: - (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = - new MountSecretsFeatureStep(_)) { + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => MountSecretsFeatureStep = + new MountSecretsFeatureStep(_), + provideEnvSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq(provideBasicStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) - } else baseFeatures + + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val secretEnvFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil + + val allFeatures = baseFeatures ++ secretFeature ++ secretEnvFeature ++ volumesFeature + var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { executorPod = feature.configurePod(executorPod) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala new file mode 100644 index 0000000000000..527fc6b0d8f87 --- /dev/null +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, HasMetadata, Pod, PodList} +import io.fabric8.kubernetes.client.{Watch, Watcher} +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} + +object Fabric8Aliases { + type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + type SINGLE_POD = PodResource[Pod, DoneablePod] + type RESOURCE_LIST = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ + HasMetadata, Boolean] +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index f10202f7a3546..e3c19cdb81567 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -22,7 +22,7 @@ import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.deploy.k8s.submit._ class KubernetesConfSuite extends SparkFunSuite { @@ -40,6 +40,9 @@ class KubernetesConfSuite extends SparkFunSuite { private val SECRET_NAMES_TO_MOUNT_PATHS = Map( "secret1" -> "/mnt/secrets/secret1", "secret2" -> "/mnt/secrets/secret2") + private val SECRET_ENV_VARS = Map( + "envName1" -> "name1:key1", + "envName2" -> "name2:key2") private val CUSTOM_ENVS = Map( "customEnvKey1" -> "customEnvValue1", "customEnvKey2" -> "customEnvValue2") @@ -53,9 +56,10 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(conf.appId === APP_ID) assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) @@ -76,7 +80,8 @@ class KubernetesConfSuite extends SparkFunSuite { APP_ID, mainAppJar, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") .split(",") === Array("local:///opt/spark/jar1.jar", 
"local:///opt/spark/main.jar")) @@ -85,15 +90,81 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithoutMainJar.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.1) } - test("Resolve driver labels, annotations, secret mount paths, and envs.") { + test("Creating driver conf with a python primary file") { + val mainResourceFile = "local:///opt/spark/main.py" + val inputPyFiles = Array("local:///opt/spark/example2.py", "local:///example3.py") val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + .set("spark.files", "local:///opt/spark/example4.py") + val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) + val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + Some(inputPyFiles.mkString(","))) + assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) + assert(kubernetesConfWithMainResource.sparkFiles + === Array("local:///opt/spark/example4.py", mainResourceFile) ++ inputPyFiles) + } + + test("Creating driver conf with a r primary file") { + val mainResourceFile = "local:///opt/spark/main.R" + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + .set("spark.files", "local:///opt/spark/example2.R") + val mainAppResource = Some(RMainAppResource(mainResourceFile)) + val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + maybePyFiles = None) + assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) + assert(kubernetesConfWithMainResource.sparkFiles + === Array("local:///opt/spark/example2.R", mainResourceFile)) + } + + test("Testing explicit setting of memory overhead on non-JVM tasks") { + val sparkConf = new SparkConf(false) + .set(MEMORY_OVERHEAD_FACTOR, 0.3) + + val mainResourceFile = "local:///opt/spark/main.py" + val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + None) + assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) + } + + test("Resolve driver labels, annotations, secret mount paths, envs, and memory overhead") { + val sparkConf = new SparkConf(false) + .set(MEMORY_OVERHEAD_FACTOR, 0.3) CUSTOM_LABELS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) } @@ -103,6 +174,9 @@ class KubernetesConfSuite extends SparkFunSuite { SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX$key", value) + } CUSTOM_ENVS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) } @@ -112,16 +186,19 @@ 
class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(conf.roleLabels === Map( SPARK_APP_ID_LABEL -> APP_ID, SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) assert(conf.roleEnvs === CUSTOM_ENVS) + assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) } test("Basic executor translated fields.") { @@ -129,9 +206,9 @@ class KubernetesConfSuite extends SparkFunSuite { new SparkConf(false), EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD)) assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) - assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + assert(conf.roleSpecificConf.driverPod.get === DRIVER_POD) } test("Image pull secrets.") { @@ -140,7 +217,7 @@ class KubernetesConfSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD)) assert(conf.imagePullSecrets() === Seq( new LocalObjectReferenceBuilder().withName("my-secret-1").build(), @@ -155,6 +232,9 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_ANNOTATIONS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX$key", value) + } SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) } @@ -163,13 +243,13 @@ class KubernetesConfSuite extends SparkFunSuite { sparkConf, EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD)) assert(conf.roleLabels === Map( SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, SPARK_APP_ID_LABEL -> APP_ID, SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) } - } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala new file mode 100644 index 0000000000000..d795d159773a8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class KubernetesVolumeUtilsSuite extends SparkFunSuite { + test("Parses hostPath volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === + KubernetesHostPathVolumeConf("/hostPath")) + } + + test("Parses persistentVolumeClaim volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimeName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf("claimeName")) + } + + test("Parses emptyDir volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.options.medium", "medium") + sparkConf.set("test.emptyDir.volumeName.options.sizeLimit", "5G") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(Some("medium"), Some("5G"))) + } + + test("Parses emptyDir volume options can be optional") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(None, None)) + } + + test("Defaults optional readOnly to false") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.mountReadOnly === false) + } + + test("Gracefully fails on missing mount key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mnt.path", "/path") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "emptyDir.volumeName.mount.path") + } + + 
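
The suite around this point pins down the key convention the new parser understands: [prefix][volumeType].[volumeName].mount.path, .mount.readOnly and .options.*, with each entry surfacing as a volume spec wrapped in a Try. A small usage sketch with the same "test." prefix the tests use, limited to calls shown in this patch:

import org.apache.spark.SparkConf
import org.apache.spark.deploy.k8s.{KubernetesHostPathVolumeConf, KubernetesVolumeUtils}

// Declare one read-only hostPath volume through prefixed conf keys and parse it back.
val conf = new SparkConf(false)
  .set("test.hostPath.logs.mount.path", "/var/log/spark")
  .set("test.hostPath.logs.mount.readOnly", "true")
  .set("test.hostPath.logs.options.path", "/var/log")

val parsed = KubernetesVolumeUtils.parseVolumesWithPrefix(conf, "test.").head  // a Try
val spec = parsed.get
assert(spec.volumeName == "logs")
assert(spec.mountPath == "/var/log/spark")
assert(spec.mountReadOnly)
assert(spec.volumeConf == KubernetesHostPathVolumeConf("/var/log"))
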
test("Gracefully fails on missing option key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.pth", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "hostPath.volumeName.options.path") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index eee85b8baa730..d98e113554648 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder +import io.fabric8.kubernetes.api.model.{ContainerPort, ContainerPortBuilder, LocalObjectReferenceBuilder} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource +import org.apache.spark.ui.SparkUI class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -33,6 +36,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val PY_MAIN_CLASS = "org.apache.spark.deploy.PythonRunner" private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" @@ -47,6 +51,12 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { TEST_IMAGE_PULL_SECRETS.map { secret => new LocalObjectReferenceBuilder().withName(secret).build() } + private val emptyDriverSpecificConf = KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS) + test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() @@ -59,17 +69,16 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - None, - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, - DRIVER_ENVS) + Map.empty, + DRIVER_ENVS, + Nil, + Seq.empty[String]) val featureStep = new BasicDriverFeatureStep(kubernetesConf) val basePod = SparkPod.initialPod() @@ -79,6 +88,14 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(configuredPod.container.getImage === "spark-driver:latest") assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + val expectedPortNames = Set( + containerPort(DRIVER_PORT_NAME, DEFAULT_DRIVER_PORT), + 
containerPort(BLOCK_MANAGER_PORT_NAME, DEFAULT_BLOCKMANAGER_PORT), + containerPort(UI_PORT_NAME, SparkUI.DEFAULT_PORT) + ) + val foundPortNames = configuredPod.container.getPorts.asScala.toSet + assert(expectedPortNames === foundPortNames) + assert(configuredPod.container.getEnv.size === 3) val envs = configuredPod.container .getEnv @@ -109,7 +126,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") - val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, @@ -118,6 +134,52 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) } + test("Check appropriate entrypoint rerouting for various bindings") { + val javaSparkConf = new SparkConf() + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(CONTAINER_IMAGE, "spark-driver:latest") + val pythonSparkConf = new SparkConf() + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(CONTAINER_IMAGE, "spark-driver:latest") + val javaKubernetesConf = KubernetesConf( + javaSparkConf, + KubernetesDriverSpecificConf( + Some(JavaMainAppResource("")), + APP_NAME, + PY_MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Nil, + Seq.empty[String]) + val pythonKubernetesConf = KubernetesConf( + pythonSparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("")), + APP_NAME, + PY_MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Nil, + Seq.empty[String]) + val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) + val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) + val basePod = SparkPod.initialPod() + val configuredJavaPod = javaFeatureStep.configurePod(basePod) + val configuredPythonPod = pythonFeatureStep.configurePod(basePod) + } + test("Additional system properties resolve jars and set cluster-mode confs.") { val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") @@ -128,17 +190,17 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, "spark-driver:latest") val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - None, - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, - Map.empty) + Map.empty, + DRIVER_ENVS, + Nil, + allFiles) + val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( @@ -150,4 +212,11 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") assert(additionalProperties === expectedSparkConf) } + + def containerPort(name: String, portNumber: Int): ContainerPort = + new ContainerPortBuilder() + .withName(name) + .withContainerPort(portNumber) + .withProtocol("TCP") + .build() } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index a764f7630b5c8..95d373f791649 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -81,13 +81,16 @@ class BasicExecutorFeatureStepSuite val step = new BasicExecutorFeatureStep( KubernetesConf( baseConf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), RESOURCE_NAME_PREFIX, APP_ID, LABELS, ANNOTATIONS, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) // The executor pod name and default labels. @@ -118,13 +121,16 @@ class BasicExecutorFeatureStepSuite val step = new BasicExecutorFeatureStep( KubernetesConf( conf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), longPodNamePrefix, APP_ID, LABELS, ANNOTATIONS, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -136,13 +142,16 @@ class BasicExecutorFeatureStepSuite val step = new BasicExecutorFeatureStep( KubernetesConf( conf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), RESOURCE_NAME_PREFIX, APP_ID, LABELS, ANNOTATIONS, Map.empty, - Map("qux" -> "quux"))) + Map.empty, + Map("qux" -> "quux"), + Nil, + Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) checkEnv(executor, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 9f817d3bfc79a..7e916b3854404 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -59,7 +59,10 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) @@ -88,7 +91,10 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -124,7 +130,10 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = 
kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index c299d56865ec0..8b91e93eecd8c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -65,7 +65,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) @@ -94,7 +97,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX val expectedHostName = s"$expectedServiceName.my-namespace.svc" @@ -113,7 +119,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) val resolvedService = configurationStep .getAdditionalKubernetesResources() .head @@ -141,7 +150,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty), + Map.empty, + Map.empty, + Nil, + Seq.empty[String]), clock) val driverService = configurationStep .getAdditionalKubernetesResources() @@ -166,7 +178,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty), + Map.empty, + Map.empty, + Nil, + Seq.empty[String]), clock) fail("The driver bind address should not be allowed.") } catch { @@ -189,7 +204,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty), + Map.empty, + Map.empty, + Nil, + Seq.empty[String]), clock) fail("The driver host address should not be allowed.") } catch { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala new file mode 100644 index 0000000000000..85c6cb282d2b0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class EnvSecretsFeatureStepSuite extends SparkFunSuite{ + private val KEY_REF_NAME_FOO = "foo" + private val KEY_REF_NAME_BAR = "bar" + private val KEY_REF_KEY_FOO = "key_foo" + private val KEY_REF_KEY_BAR = "key_bar" + private val ENV_NAME_FOO = "MY_FOO" + private val ENV_NAME_BAR = "MY_bar" + + test("sets up all keyRefs") { + val baseDriverPod = SparkPod.initialPod() + val envVarsToKeys = Map( + ENV_NAME_BAR -> s"${KEY_REF_NAME_BAR}:${KEY_REF_KEY_BAR}", + ENV_NAME_FOO -> s"${KEY_REF_NAME_FOO}:${KEY_REF_KEY_FOO}") + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", Some(new PodBuilder().build())), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + Map.empty, + envVarsToKeys, + Map.empty, + Nil, + Seq.empty[String]) + + val step = new EnvSecretsFeatureStep(kubernetesConf) + val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container + + val expectedVars = + Seq(s"${ENV_NAME_BAR}", s"${ENV_NAME_FOO}") + + expectedVars.foreach { envName => + assert(KubernetesFeaturesTestUtils.containerHasEnvVar(driverContainerWithEnvSecrets, envName)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index 27bff74ce38af..f90380e30e52a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -58,4 +60,7 @@ object KubernetesFeaturesTestUtils { .build()) } + def containerHasEnvVar(container: Container, envVarName: String): Boolean = { + container.getEnv.asScala.exists(envVar => envVar.getName == envVarName) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala new file mode 100644 index 0000000000000..a339827b819a9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder} +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + private val defaultLocalDir = "/var/data/default-local-dir" + private var sparkConf: SparkConf = _ + private var kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf] = _ + + before { + val realSparkConf = new SparkConf(false) + sparkConf = Mockito.spy(realSparkConf) + kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + "resource", + "app-id", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + } + + test("Resolve to default local dir if neither env nor configuration are set") { + Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") + Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 1) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath(defaultLocalDir) + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue(defaultLocalDir) + .build()) + } + + test("Use configured local dirs split on comma if provided.") { + Mockito.doReturn("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 2) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.pod.getSpec.getVolumes.get(1) === + new VolumeBuilder() + .withName(s"spark-local-dir-2") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 2) + assert(configuredPod.container.getVolumeMounts.get(0) === + 
new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath("/var/data/my-local-dir-1") + .build()) + assert(configuredPod.container.getVolumeMounts.get(1) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-2") + .withMountPath("/var/data/my-local-dir-2") + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .build()) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 9d02f56cc206d..dad610c443acc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -35,13 +35,16 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { val sparkConf = new SparkConf(false) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + KubernetesExecutorSpecificConf("1", Some(new PodBuilder().build())), "resource-name-prefix", "app-id", Map.empty, Map.empty, secretNamesToMountPaths, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) val step = new MountSecretsFeatureStep(kubernetesConf) val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala new file mode 100644 index 0000000000000..d309aa94ec115 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class MountVolumesFeatureStepSuite extends SparkFunSuite { + private val sparkConf = new SparkConf(false) + private val emptyKubernetesConf = KubernetesConf( + sparkConf = sparkConf, + roleSpecificConf = KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + appResourceNamePrefix = "resource", + appId = "app-id", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Nil) + + test("Mounts hostPath volumes") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts pesistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === true) + + } + + test("Mounts emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "Memory") + assert(emptyDir.getSizeLimit.getAmount === "6G") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts emptyDir with no options") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val 
step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "") + assert(emptyDir.getSizeLimit.getAmount === null) + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts multiple volumes") { + val hpVolumeConf = KubernetesVolumeSpec( + "hpVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val pvcVolumeConf = KubernetesVolumeSpec( + "checkpointVolume", + "/checkpoints", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val volumesConf = hpVolumeConf :: pvcVolumeConf :: Nil + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumesConf) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..18874afe6e53a --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource + +class JavaDriverFeatureStepSuite extends SparkFunSuite { + + test("Java Step modifies container correctly") { + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(JavaMainAppResource("local:///main.jar")), + "test-class", + "java-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Seq.empty[String]) + + val step = new JavaDriverFeatureStep(kubernetesConf) + val driverPod = step.configurePod(baseDriverPod).pod + val driverContainerwithJavaStep = step.configurePod(baseDriverPod).container + assert(driverContainerwithJavaStep.getArgs.size === 7) + val args = driverContainerwithJavaStep + .getArgs.asScala + assert(args === List( + "driver", + "--properties-file", SPARK_CONF_PATH, + "--class", "test-class", + "spark-internal", "5 7")) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..a5dac6869327d --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource + +class PythonDriverFeatureStepSuite extends SparkFunSuite { + + test("Python Step modifies container correctly") { + val expectedMainResource = "/main.py" + val mainResource = "local:///main.py" + val pyFiles = Seq("local:///example2.py", "local:///example3.py") + val expectedPySparkFiles = + "/example2.py:/example3.py" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) + .set(KUBERNETES_PYSPARK_PY_FILES, pyFiles.mkString(",")) + .set("spark.files", "local:///example.py") + .set(PYSPARK_MAJOR_PYTHON_VERSION, "2") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.py")), + "test-app", + "python-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Seq.empty[String]) + + val step = new PythonDriverFeatureStep(kubernetesConf) + val driverPod = step.configurePod(baseDriverPod).pod + val driverContainerwithPySpark = step.configurePod(baseDriverPod).container + assert(driverContainerwithPySpark.getEnv.size === 4) + val envs = driverContainerwithPySpark + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_PYSPARK_PRIMARY) === expectedMainResource) + assert(envs(ENV_PYSPARK_FILES) === expectedPySparkFiles) + assert(envs(ENV_PYSPARK_ARGS) === "5 7") + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "2") + } + test("Python Step testing empty pyfiles") { + val mainResource = "local:///main.py" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) + .set(PYSPARK_MAJOR_PYTHON_VERSION, "3") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.py")), + "test-class-py", + "python-runner", + Seq.empty[String]), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Seq.empty[String]) + val step = new PythonDriverFeatureStep(kubernetesConf) + val driverContainerwithPySpark = step.configurePod(baseDriverPod).container + val args = driverContainerwithPySpark + .getArgs.asScala + assert(driverContainerwithPySpark.getArgs.size === 5) + assert(args === List( + "driver-py", + "--properties-file", SPARK_CONF_PATH, + "--class", "test-class-py")) + val envs = driverContainerwithPySpark + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "3") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..8fdf91ef638f2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.RMainAppResource + +class RDriverFeatureStepSuite extends SparkFunSuite { + + test("R Step modifies container correctly") { + val expectedMainResource = "/main.R" + val mainResource = "local:///main.R" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_R_MAIN_APP_RESOURCE, mainResource) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(RMainAppResource(mainResource)), + "test-app", + "r-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Seq.empty, + sparkFiles = Seq.empty[String]) + + val step = new RDriverFeatureStep(kubernetesConf) + val driverContainerwithR = step.configurePod(baseDriverPod).container + assert(driverContainerwithR.getEnv.size === 2) + val envs = driverContainerwithR + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_R_PRIMARY) === expectedMainResource) + assert(envs(ENV_R_ARGS) === "5 7") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index c1b203e03a357..4d8e79189ff32 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ class ClientSuite extends SparkFunSuite with 
BeforeAndAfter { @@ -103,15 +104,11 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { .build() } - private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ - HasMetadata, Boolean] - private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - @Mock private var kubernetesClient: KubernetesClient = _ @Mock - private var podOperations: Pods = _ + private var podOperations: PODS = _ @Mock private var namedPods: PodResource[Pod, DoneablePod] = _ @@ -123,7 +120,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private var driverBuilder: KubernetesDriverBuilder = _ @Mock - private var resourceList: ResourceList = _ + private var resourceList: RESOURCE_LIST = _ private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ @@ -142,7 +139,10 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(POD_NAME)).thenReturn(namedPods) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 161f9afe7bba9..4117c5487a41e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -17,15 +17,23 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val CREDENTIALS_STEP_TYPE = "credentials" private val SERVICE_STEP_TYPE = "service" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" + private val JAVA_STEP_TYPE = "java-bindings" + private val PYSPARK_STEP_TYPE = "pyspark-bindings" + private val R_STEP_TYPE = "r-bindings" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -36,21 +44,45 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + private val localDirsStep = 
KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val javaStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + JAVA_STEP_TYPE, classOf[JavaDriverFeatureStep]) + + private val pythonStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + PYSPARK_STEP_TYPE, classOf[PythonDriverFeatureStep]) + + private val rStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + R_STEP_TYPE, classOf[RDriverFeatureStep]) + + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, - _ => secretsStep) + _ => secretsStep, + _ => envSecretsStep, + _ => localDirsStep, + _ => mountVolumesStep, + _ => pythonStep, + _ => rStep, + _ => javaStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - None, + Some(JavaMainAppResource("example.jar")), "test-app", "main", Seq.empty), @@ -59,12 +91,17 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE) + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -80,13 +117,129 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), - Map.empty) + Map("EnvName" -> "SecretName:secretKey"), + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply Java step if main resource is none.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply Python step if main resource is python.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("example.py")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + PYSPARK_STEP_TYPE) + } + + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/path")) + 
val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply R step if main resource is R.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + Some(RMainAppResource("example.R")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, - SECRETS_STEP_TYPE) + LOCAL_DIRS_STEP_TYPE, + R_STEP_TYPE) } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..f7721e6fd6388 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import scala.collection.mutable + +class DeterministicExecutorPodsSnapshotsStore extends ExecutorPodsSnapshotsStore { + + private val snapshotsBuffer = mutable.Buffer.empty[ExecutorPodsSnapshot] + private val subscribers = mutable.Buffer.empty[Seq[ExecutorPodsSnapshot] => Unit] + + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + subscribers += onNewSnapshots + } + + override def stop(): Unit = {} + + def notifySubscribers(): Unit = { + subscribers.foreach(_(snapshotsBuffer)) + snapshotsBuffer.clear() + } + + override def updatePod(updatedPod: Pod): Unit = { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + snapshotsBuffer += currentSnapshot + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + snapshotsBuffer += currentSnapshot + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala new file mode 100644 index 0000000000000..c6b667ed85e8c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, Pod, PodBuilder} + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkPod + +object ExecutorLifecycleTestUtils { + + val TEST_SPARK_APP_ID = "spark-app-id" + + def failedExecutorWithoutDeletion(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("failed") + .addNewContainerStatus() + .withName("spark-executor") + .withImage("k8s-spark") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def pendingExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("pending") + .endStatus() + .build() + } + + def runningExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() + } + + def succeededExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("succeeded") + .endStatus() + .build() + } + + def deletedExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewMetadata() + .withNewDeletionTimestamp("523012521") + .endMetadata() + .build() + } + + def unknownExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("unknown") + .endStatus() + .build() + } + + def podWithAttachedContainerForId(executorId: Long): Pod = { + val sparkPod = executorPodWithId(executorId) + val podWithAttachedContainer = new PodBuilder(sparkPod.pod) + .editOrNewSpec() + .addToContainers(sparkPod.container) + .endSpec() + .build() + podWithAttachedContainer + } + + def executorPodWithId(executorId: Long): SparkPod = { + val pod = new PodBuilder() + .withNewMetadata() + .withName(s"spark-executor-$executorId") + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + val container = new ContainerBuilder() + .withName("spark-executor") + .withImage("k8s-spark") + .build() + SparkPod(pod, container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala new file mode 100644 index 0000000000000..e847f8590d353 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{never, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ +import org.apache.spark.util.ManualClock + +class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { + + private val driverPodName = "driver" + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .withUid("driver-pod-uid") + .endMetadata() + .build() + + private val conf = new SparkConf().set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L) + + private var waitForExecutorPodsClock: ManualClock = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var labeledPods: LABELED_PODS = _ + + @Mock + private var driverPodOperations: PodResource[Pod, DoneablePod] = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + + private var podsAllocatorUnderTest: ExecutorPodsAllocator = _ + + before { + MockitoAnnotations.initMocks(this) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) + when(driverPodOperations.get).thenReturn(driverPod) + when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields())) + .thenAnswer(executorPodAnswer()) + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + waitForExecutorPodsClock = new ManualClock(0L) + podsAllocatorUnderTest = new ExecutorPodsAllocator( + conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) + podsAllocatorUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Initially request executors in batches. 
Do not request another batch if the" + + " first has not finished.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (nextId <- 1 to podAllocationSize) { + verify(podOperations).create(podWithAttachedContainerForId(nextId)) + } + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("Request executors in batches. Allow another batch to be requested if" + + " all pending executors start running.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + snapshotsStore.notifySubscribers() + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations, times(podAllocationSize + 1)).create(any(classOf[Pod])) + } + + test("When a current batch reaches error states immediately, re-request" + + " them on the next batch.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + val failedPod = failedExecutorWithoutDeletion(podAllocationSize) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("When an executor is requested but the API does not report it in a reasonable time, retry" + + " requesting that executor.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + waitForExecutorPodsClock.setTime(podCreationTimeout + 1) + when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods) + snapshotsStore.notifySubscribers() + verify(labeledPods).delete() + verify(podOperations).create(podWithAttachedContainerForId(2)) + } + + private def executorPodAnswer(): Answer[SparkPod] = { + new Answer[SparkPod] { + override def answer(invocation: InvocationOnMock): SparkPod = { + val k8sConf = invocation.getArgumentAt( + 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]]) + executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt) + } + } + } + + private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = + Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any): Boolean = { + if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + false + } else { + val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + val executorSpecificConf = k8sConf.roleSpecificConf + val expectedK8sConf = KubernetesConf.createExecutorConf( + conf, + executorSpecificConf.executorId, + TEST_SPARK_APP_ID, + Some(driverPod)) + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && 
+ // Since KubernetesConf.createExecutorConf clones the SparkConf object, force + // deep equality comparison for the SparkConf object and use object equality + // comparison on all other fields. + k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + } + } + }) + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala new file mode 100644 index 0000000000000..562ace9f49d4d --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.CacheBuilder +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter { + + private var namedExecutorPods: mutable.Map[String, PodResource[Pod, DoneablePod]] = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + @Mock + private var schedulerBackend: KubernetesClusterSchedulerBackend = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + private var eventHandlerUnderTest: ExecutorPodsLifecycleManager = _ + + before { + MockitoAnnotations.initMocks(this) + val removedExecutorsCache = CacheBuilder.newBuilder().build[java.lang.Long, java.lang.Long] + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + namedExecutorPods = mutable.Map.empty[String, PodResource[Pod, DoneablePod]] + when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty[String]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) + eventHandlerUnderTest = new 
ExecutorPodsLifecycleManager( + new SparkConf(), + executorBuilder, + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + eventHandlerUnderTest.start(schedulerBackend) + } + + test("When an executor reaches error states immediately, remove from the scheduler backend.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName)).delete() + } + + test("Don't remove executors twice from Spark but remove from K8s repeatedly.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete() + } + + test("When the scheduler backend lists executor ids that aren't present in the cluster," + + " remove those executors from Spark.") { + when(schedulerBackend.getExecutorIds()).thenReturn(Seq("1")) + val msg = s"The executor with ID 1 was not found in the cluster but we didn't" + + s" get a reason why. Marking the executor as failed. The executor may have been" + + s" deleted but the driver missed the deletion event." + val expectedLossReason = ExecutorExited(-1, exitCausedByApp = false, msg) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + + private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + s""" + |The executor with id $failedExecutorId exited with exit code 1. + |The API gave the following brief reason: ${failedPod.getStatus.getReason} + |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = { + new Answer[PodResource[Pod, DoneablePod]] { + override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = { + val podName = invocation.getArgumentAt(0, classOf[String]) + namedExecutorPods.getOrElseUpdate( + podName, mock(classOf[PodResource[Pod, DoneablePod]])) + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..1b26d6af296a5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit + +import io.fabric8.kubernetes.api.model.PodListBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsPollingSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + private val sparkConf = new SparkConf + + private val pollingInterval = sparkConf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + private var pollingExecutor: DeterministicScheduler = _ + private var pollingSourceUnderTest: ExecutorPodsPollingSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + pollingExecutor = new DeterministicScheduler() + pollingSourceUnderTest = new ExecutorPodsPollingSnapshotSource( + sparkConf, + kubernetesClient, + eventQueue, + pollingExecutor) + pollingSourceUnderTest.start(TEST_SPARK_APP_ID) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + } + + test("Items returned by the API should be pushed to the event queue") { + when(executorRoleLabeledPods.list()) + .thenReturn(new PodListBuilder() + .addToItems( + runningExecutor(1), + runningExecutor(2)) + .build()) + pollingExecutor.tick(pollingInterval, TimeUnit.MILLISECONDS) + verify(eventQueue).replaceSnapshot(Seq(runningExecutor(1), runningExecutor(2))) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala new file mode 100644 index 0000000000000..70e19c904eddb --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsSnapshotSuite extends SparkFunSuite { + + test("States are interpreted correctly from pod metadata.") { + val pods = Seq( + pendingExecutor(0), + runningExecutor(1), + succeededExecutor(2), + failedExecutorWithoutDeletion(3), + deletedExecutor(4), + unknownExecutor(5)) + val snapshot = ExecutorPodsSnapshot(pods) + assert(snapshot.executorPods === + Map( + 0L -> PodPending(pods(0)), + 1L -> PodRunning(pods(1)), + 2L -> PodSucceeded(pods(2)), + 3L -> PodFailed(pods(3)), + 4L -> PodDeleted(pods(4)), + 5L -> PodUnknown(pods(5)))) + } + + test("Updates add new pods for non-matching ids and edit existing pods for matching ids") { + val originalPods = Seq( + pendingExecutor(0), + runningExecutor(1)) + val originalSnapshot = ExecutorPodsSnapshot(originalPods) + val snapshotWithUpdatedPod = originalSnapshot.withUpdate(succeededExecutor(1)) + assert(snapshotWithUpdatedPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)))) + val snapshotWithNewPod = snapshotWithUpdatedPod.withUpdate(pendingExecutor(2)) + assert(snapshotWithNewPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)), + 2L -> PodPending(pendingExecutor(2)))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala new file mode 100644 index 0000000000000..cf54b3c4eb329 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference + +import io.fabric8.kubernetes.api.model.{Pod, PodBuilder} +import org.jmock.lib.concurrent.DeterministicScheduler +import org.scalatest.BeforeAndAfter +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodsSnapshotsStoreSuite extends SparkFunSuite with BeforeAndAfter { + + private var eventBufferScheduler: DeterministicScheduler = _ + private var eventQueueUnderTest: ExecutorPodsSnapshotsStoreImpl = _ + + before { + eventBufferScheduler = new DeterministicScheduler() + eventQueueUnderTest = new ExecutorPodsSnapshotsStoreImpl(eventBufferScheduler) + } + + test("Subscribers get notified of events periodically.") { + val receivedSnapshots1 = mutable.Buffer.empty[ExecutorPodsSnapshot] + val receivedSnapshots2 = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots1 ++= _ + } + eventQueueUnderTest.addSubscriber(2000) { + receivedSnapshots2 ++= _ + } + + eventBufferScheduler.runUntilIdle() + assert(receivedSnapshots1 === Seq(ExecutorPodsSnapshot())) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + pushPodWithIndex(1) + // Force time to move forward so that the buffer is emitted, scheduling the + // processing task on the subscription executor... + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + // ... then actually execute the subscribers. + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + // Don't repeat snapshots + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + pushPodWithIndex(2) + pushPodWithIndex(3) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots1 === receivedSnapshots2) + } + + test("Even without sending events, initially receive an empty buffer.") { + val receivedInitialSnapshot = new AtomicReference[Seq[ExecutorPodsSnapshot]](null) + eventQueueUnderTest.addSubscriber(1000) { + receivedInitialSnapshot.set + } + assert(receivedInitialSnapshot.get == null) + eventBufferScheduler.runUntilIdle() + assert(receivedInitialSnapshot.get === Seq(ExecutorPodsSnapshot())) + } + + test("Replacing the snapshot passes the new snapshot to subscribers.") { + val receivedSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots ++= _ + } + eventQueueUnderTest.updatePod(podWithIndex(1)) + 
eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + eventQueueUnderTest.replaceSnapshot(Seq(podWithIndex(2))) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(2))))) + } + + private def pushPodWithIndex(index: Int): Unit = + eventQueueUnderTest.updatePod(podWithIndex(index)) + + private def podWithIndex(index: Int): Pod = + new PodBuilder() + .editOrNewMetadata() + .withName(s"pod-$index") + .addToLabels(SPARK_EXECUTOR_ID_LABEL, index.toString) + .endMetadata() + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..ac1968b4ff810 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var watchConnection: Watch = _ + + private var watch: ArgumentCaptor[Watcher[Pod]] = _ + + private var watchSourceUnderTest: ExecutorPodsWatchSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + watch = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection) + watchSourceUnderTest = new ExecutorPodsWatchSnapshotSource( + eventQueue, kubernetesClient) + watchSourceUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Watch events should be pushed to the snapshots store as snapshot updates.") { + watch.getValue.eventReceived(Action.ADDED, runningExecutor(1)) + watch.getValue.eventReceived(Action.MODIFIED, runningExecutor(2)) + verify(eventQueue).updatePod(runningExecutor(1)) + verify(eventQueue).updatePod(runningExecutor(2)) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 96065e83f069c..52e7a12dbaf06 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -16,85 +16,36 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} -import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.hamcrest.{BaseMatcher, Description, Matcher} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{doNothing, never, times, verify, when} +import io.fabric8.kubernetes.client.KubernetesClient +import 
org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{eq => mockitoEq} +import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ -import scala.collection.JavaConverters._ -import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.ThreadUtils +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { - private val APP_ID = "test-spark-app" - private val DRIVER_POD_NAME = "spark-driver-pod" - private val NAMESPACE = "test-namespace" - private val SPARK_DRIVER_HOST = "localhost" - private val SPARK_DRIVER_PORT = 7077 - private val POD_ALLOCATION_INTERVAL = "1m" - private val FIRST_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod1") - .endMetadata() - .withNewSpec() - .withNodeName("node1") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private val SECOND_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod2") - .endMetadata() - .withNewSpec() - .withNodeName("node2") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.101") - .endStatus() - .build() - - private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - private type LABELED_PODS = FilterWatchListDeletable[ - Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] - private type IN_NAMESPACE_PODS = NonNamespaceOperation[ - Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - - @Mock - private var sparkContext: SparkContext = _ - - @Mock - private var listenerBus: LiveListenerBus = _ - - @Mock - private var taskSchedulerImpl: TaskSchedulerImpl = _ + private val requestExecutorsService = new DeterministicScheduler() + private val sparkConf = new SparkConf(false) + .set("spark.executor.instances", "3") @Mock - private var allocatorExecutor: ScheduledExecutorService = _ + private var sc: SparkContext = _ @Mock - private var requestExecutorsService: ExecutorService = _ + private var rpcEnv: RpcEnv = _ @Mock - private var executorBuilder: KubernetesExecutorBuilder = _ + private var driverEndpointRef: RpcEndpointRef = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -103,347 +54,97 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var podOperations: PODS = _ @Mock - private var podsWithLabelOperations: LABELED_PODS = _ + private var labeledPods: LABELED_PODS = _ @Mock - private var podsInNamespace: IN_NAMESPACE_PODS = _ + private var taskScheduler: 
TaskSchedulerImpl = _ @Mock - private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + private var eventQueue: ExecutorPodsSnapshotsStore = _ @Mock - private var rpcEnv: RpcEnv = _ + private var podAllocator: ExecutorPodsAllocator = _ @Mock - private var driverEndpointRef: RpcEndpointRef = _ + private var lifecycleEventHandler: ExecutorPodsLifecycleManager = _ @Mock - private var executorPodsWatch: Watch = _ + private var watchEvents: ExecutorPodsWatchSnapshotSource = _ @Mock - private var successFuture: Future[Boolean] = _ + private var pollEvents: ExecutorPodsPollingSnapshotSource = _ - private var sparkConf: SparkConf = _ - private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ - private var allocatorRunnable: ArgumentCaptor[Runnable] = _ - private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ - - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(DRIVER_POD_NAME) - .addToLabels(SPARK_APP_ID_LABEL, APP_ID) - .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) - .endMetadata() - .build() + private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ before { MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) - .set(KUBERNETES_NAMESPACE, NAMESPACE) - .set("spark.driver.host", SPARK_DRIVER_HOST) - .set("spark.driver.port", SPARK_DRIVER_PORT.toString) - .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL) - executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) - allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) - requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + when(taskScheduler.sc).thenReturn(sc) + when(sc.conf).thenReturn(sparkConf) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(sparkContext.conf).thenReturn(sparkConf) - when(sparkContext.listenerBus).thenReturn(listenerBus) - when(taskSchedulerImpl.sc).thenReturn(sparkContext) - when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) - when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) - .thenReturn(executorPodsWatch) - when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) - when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) - when(podsWithDriverName.get()).thenReturn(driverPod) - when(allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable.capture(), - mockitoEq(0L), - mockitoEq(TimeUnit.MINUTES.toMillis(1)), - mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null) - // Creating Futures in Scala backed by a Java executor service resolves to running - // ExecutorService#execute (as opposed to submit) - doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) when(rpcEnv.setupEndpoint( mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) .thenReturn(driverEndpointRef) - - // Used by the CoarseGrainedSchedulerBackend when making RPC calls. - when(driverEndpointRef.ask[Boolean] - (any(classOf[Any])) - (any())).thenReturn(successFuture) - when(successFuture.failed).thenReturn(Future[Throwable] { - // emulate behavior of the Future.failed method. 
- throw new NoSuchElementException() - }(ThreadUtils.sameThread)) - } - - test("Basic lifecycle expectations when starting and stopping the scheduler.") { - val scheduler = newSchedulerBackend() - scheduler.start() - assert(executorPodsWatcherArgument.getValue != null) - assert(allocatorRunnable.getValue != null) - scheduler.stop() - verify(executorPodsWatch).close() - } - - test("Static allocation should request executors upon first allocator run.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations).create(secondResolvedPod) - } - - test("Killing executors deletes the executor pods") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - scheduler.doKillExecutors(Seq("2")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations).delete(secondResolvedPod) - verify(podOperations, never()).delete(firstResolvedPod) - } - - test("Executors should be requested in batches.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations, never()).create(secondResolvedPod) - val registerFirstExecutorMessage = RegisterExecutor( - "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - allocatorRunnable.getValue.run() - verify(podOperations).create(secondResolvedPod) - } - - test("Scaled down executors should be cleaned up") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - - // The scheduler backend spins up one executor pod. 
- requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - // Request that there are 0 executors and trigger deletion from driver. - scheduler.doRequestTotalExecutors(0) - requestExecutorRunnable.getAllValues.asScala.last.run() - scheduler.doKillExecutors(Seq("1")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations, times(1)).delete(resolvedPod) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - - val exitedPod = exitPod(resolvedPod, 0) - executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) - allocatorRunnable.getValue.run() - - // No more deletion attempts of the executors. - // This is graceful termination and should not be detected as a failure. - verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).send( - RemoveExecutor("1", ExecutorExited( - 0, - exitCausedByApp = false, - s"Container in pod ${exitedPod.getMetadata.getName} exited from" + - s" explicit termination request."))) - } - - test("Executors that fail should not be deleted.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - executorPodsWatcherArgument.getValue.eventReceived( - Action.ERROR, exitPod(firstResolvedPod, 1)) - - // A replacement executor should be created but the error pod should persist. 
- val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - scheduler.doRequestTotalExecutors(1) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getAllValues.asScala.last.run() - verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", ExecutorExited( - 1, - exitCausedByApp = true, - s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + - " exit status code 1."))) - } - - test("Executors disconnected due to unknown reasons are deleted and replaced.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val executorLostReasonCheckMaxAttempts = sparkConf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - 1 to executorLostReasonCheckMaxAttempts foreach { _ => - allocatorRunnable.getValue.run() - verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + when(kubernetesClient.pods()).thenReturn(podOperations) + schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( + taskScheduler, + rpcEnv, + kubernetesClient, + requestExecutorsService, + eventQueue, + podAllocator, + lifecycleEventHandler, + watchEvents, + pollEvents) { + override def applicationId(): String = TEST_SPARK_APP_ID } - - val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } - test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(firstResolvedPod)) - .thenThrow(new RuntimeException("test")) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Start all components") { + schedulerBackendUnderTest.start() + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator).start(TEST_SPARK_APP_ID) + verify(lifecycleEventHandler).start(schedulerBackendUnderTest) + verify(watchEvents).start(TEST_SPARK_APP_ID) + verify(pollEvents).start(TEST_SPARK_APP_ID) } - test("Executors that are initially created but the watch notices them fail are rebuilt" + - " in the 
next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Stop all components") { + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + schedulerBackendUnderTest.stop() + verify(eventQueue).stop() + verify(watchEvents).stop() + verify(pollEvents).stop() + verify(labeledPods).delete() + verify(kubernetesClient).close() } - private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { - new KubernetesClusterSchedulerBackend( - taskSchedulerImpl, - rpcEnv, - executorBuilder, - kubernetesClient, - allocatorExecutor, - requestExecutorsService) { - - override def applicationId(): String = APP_ID - } + test("Remove executor") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRemoveExecutor( + "1", ExecutorKilled) + verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled)) } - private def exitPod(basePod: Pod, exitCode: Int): Pod = { - new PodBuilder(basePod) - .editStatus() - .addNewContainerStatus() - .withNewState() - .withNewTerminated() - .withExitCode(exitCode) - .endTerminated() - .endState() - .endContainerStatus() - .endStatus() - .build() + test("Kill executors") { + schedulerBackendUnderTest.start() + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + when(labeledPods.withLabelIn(SPARK_EXECUTOR_ID_LABEL, "1", "2")).thenReturn(labeledPods) + schedulerBackendUnderTest.doKillExecutors(Seq("1", "2")) + verify(labeledPods, never()).delete() + requestExecutorsService.runNextPendingCommand() + verify(labeledPods).delete() } - private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { - val resolvedPod = new PodBuilder(expectedPod) - .editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) - .endMetadata() - .build() - val resolvedContainer = new ContainerBuilder().build() - when(executorBuilder.buildFromFeatures(Matchers.argThat( - new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { - override def matches(argument: scala.Any) - : Boolean = { - argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && - argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] - .roleSpecificConf.executorId == executorId.toString - } - - override def describeTo(description: Description): Unit = {} - }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) - new PodBuilder(resolvedPod) - .editSpec() - .addToContainers(resolvedContainer) - .endSpec() - .build() + test("Request total executors") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRequestTotalExecutors(5) + 
verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator, never()).setTotalExpectedExecutors(5) + requestExecutorsService.runNextPendingCommand() + verify(podAllocator).setTotalExpectedExecutors(5) } + } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index f5270623f8acc..44fe4a24e1102 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -19,51 +19,98 @@ package org.apache.spark.scheduler.cluster.k8s import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, - _ => mountSecretsStep) + _ => mountSecretsStep, + _ => envSecretsStep, + _ => localDirsStep, + _ => mountVolumesStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( new SparkConf(false), KubernetesExecutorSpecificConf( - "executor-id", new PodBuilder().build()), + "executor-id", Some(new PodBuilder().build())), "prefix", "appId", Map.empty, Map.empty, Map.empty, - Map.empty) - validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { val conf = KubernetesConf( new SparkConf(false), KubernetesExecutorSpecificConf( - "executor-id", new PodBuilder().build()), + "executor-id", Some(new PodBuilder().build())), "prefix", "appId", Map.empty, Map.empty, Map("secret" -> "secretMountPath"), - Map.empty) + Map("secret-name" -> "secret-key"), + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + 
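The rewritten KubernetesClusterSchedulerBackendSuite above swaps the real executor-request thread pool for jMock's DeterministicScheduler, so work queued by doKillExecutors and doRequestTotalExecutors runs only when the test calls runNextPendingCommand(). The snippet below is a standalone illustration of that library behaviour, not Spark code; the object name DeterministicSchedulerDemo is invented.

    import java.util.concurrent.atomic.AtomicInteger

    import org.jmock.lib.concurrent.DeterministicScheduler

    object DeterministicSchedulerDemo extends App {
      // DeterministicScheduler implements ScheduledExecutorService but never spawns
      // threads; submitted work sits in a queue until the caller runs it explicitly.
      val scheduler = new DeterministicScheduler()
      val counter = new AtomicInteger(0)

      scheduler.execute(new Runnable {
        override def run(): Unit = counter.incrementAndGet()
      })

      assert(counter.get() == 0)        // nothing has executed yet
      scheduler.runNextPendingCommand() // the test decides exactly when it runs
      assert(counter.get() == 1)
    }

This is why the "Kill executors" and "Request total executors" tests can first verify that nothing happened, then run the pending command and verify the effect, without sleeps or races.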
SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE) + } + + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/checkpoint")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", Some(new PodBuilder().build())), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, - SECRETS_STEP_TYPE) + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE) } private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 9badf8556afc3..42a670174eae1 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -31,7 +31,7 @@ RUN set -ex && \ apk upgrade --no-cache && \ apk add --no-cache bash tini libc6-compat && \ mkdir -p /opt/spark && \ - mkdir -p /opt/spark/work-dir \ + mkdir -p /opt/spark/work-dir && \ touch /opt/spark/RELEASE && \ rm /bin/sh && \ ln -sv /bin/bash /bin/sh && \ diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile new file mode 100644 index 0000000000000..e627883ba782e --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +ARG base_img +FROM $base_img +WORKDIR / +RUN mkdir ${SPARK_HOME}/R +COPY R ${SPARK_HOME}/R + +RUN apk add --no-cache R R-dev + +ENV R_HOME /usr/lib/R + +WORKDIR /opt/spark/work-dir +ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile new file mode 100644 index 0000000000000..72bb9620b45de --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
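The KubernetesExecutorBuilderSuite above exercises a feature-step composition pattern: the builder receives one factory function per step, selects the steps that apply to a given configuration (secrets, local dirs, volume mounts, ...), and folds them over the initial pod. The sketch below shows only that shape with deliberately generic, hypothetical names (PodSpecLike, FeatureStep, ExecutorBuilderSketch); it is not the Spark implementation.

    // Hypothetical, simplified model of "conditionally selected steps folded over a pod".
    case class PodSpecLike(labels: Map[String, String], volumes: List[String])

    trait FeatureStep {
      def configure(pod: PodSpecLike): PodSpecLike
    }

    class ExecutorBuilderSketch(
        basicStep: FeatureStep,
        secretsStep: FeatureStep,
        volumesStep: FeatureStep) {

      def buildFromFeatures(
          secrets: Map[String, String],
          volumes: List[String]): PodSpecLike = {
        // Always apply the basic step; add the optional steps only when the conf asks for them.
        val steps = Seq(basicStep) ++
          (if (secrets.nonEmpty) Seq(secretsStep) else Nil) ++
          (if (volumes.nonEmpty) Seq(volumesStep) else Nil)
        steps.foldLeft(PodSpecLike(Map.empty, Nil))((pod, step) => step.configure(pod))
      }
    }

The suite's assertions about which step types appear for which configurations correspond to the conditional selection in buildFromFeatures above.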
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +ARG base_img +FROM $base_img +WORKDIR / +RUN mkdir ${SPARK_HOME}/python +COPY python/lib ${SPARK_HOME}/python/lib +# TODO: Investigate running both pip and pip3 via virtualenvs +RUN apk add --no-cache python && \ + apk add --no-cache python3 && \ + python -m ensurepip && \ + python3 -m ensurepip && \ + # We remove ensurepip since it adds no functionality since pip is + # installed on the image and it just takes up 1.6MB on the image + rm -r /usr/lib/python*/ensurepip && \ + pip install --upgrade pip setuptools && \ + # You may install with python3 packages by using pip3.6 + # Removed the .cache to save space + rm -r /root/.cache + +ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-*.zip + +WORKDIR /opt/spark/work-dir +ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 3e166116aa3fd..216e8fe31becb 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -37,20 +37,50 @@ if [ -z "$uidentry" ] ; then fi SPARK_K8S_CMD="$1" -if [ -z "$SPARK_K8S_CMD" ]; then - echo "No command to execute has been provided." 1>&2 - exit 1 -fi -shift 1 +case "$SPARK_K8S_CMD" in + driver | driver-py | driver-r | executor) + shift 1 + ;; + "") + ;; + *) + echo "Non-spark-on-k8s command provided, proceeding in pass-through mode..." + exec /sbin/tini -s -- "$@" + ;; +esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt -if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then - SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" +readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt + +if [ -n "$SPARK_EXTRA_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_EXTRA_CLASSPATH" +fi + +if [ -n "$PYSPARK_FILES" ]; then + PYTHONPATH="$PYTHONPATH:$PYSPARK_FILES" fi -if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then - cp -R "$SPARK_MOUNTED_FILES_DIR/." . 
+ +PYSPARK_ARGS="" +if [ -n "$PYSPARK_APP_ARGS" ]; then + PYSPARK_ARGS="$PYSPARK_APP_ARGS" +fi + +R_ARGS="" +if [ -n "$R_APP_ARGS" ]; then + R_ARGS="$R_APP_ARGS" +fi + +if [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "2" ]; then + pyv="$(python -V 2>&1)" + export PYTHON_VERSION="${pyv:7}" + export PYSPARK_PYTHON="python" + export PYSPARK_DRIVER_PYTHON="python" +elif [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "3" ]; then + pyv3="$(python3 -V 2>&1)" + export PYTHON_VERSION="${pyv3:7}" + export PYSPARK_PYTHON="python3" + export PYSPARK_DRIVER_PYTHON="python3" fi case "$SPARK_K8S_CMD" in @@ -62,11 +92,26 @@ case "$SPARK_K8S_CMD" in "$@" ) ;; - + driver-py) + CMD=( + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS + ) + ;; + driver-r) + CMD=( + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" $R_PRIMARY $R_ARGS + ) + ;; executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md new file mode 100644 index 0000000000000..b3863e6b7d1af --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -0,0 +1,52 @@ +--- +layout: global +title: Spark on Kubernetes Integration Tests +--- + +# Running the Kubernetes Integration Tests + +Note that the integration test framework is currently being heavily revised and +is subject to change. Note that currently the integration tests only run with Java 8. + +The simplest way to run the integration tests is to install and run Minikube, then run the following: + + dev/dev-run-integration-tests.sh + +The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should +run with a minimum of 3 CPUs and 4G of memory: + + minikube start --cpus 3 --memory 4096 + +You can download Minikube [here](https://github.com/kubernetes/minikube/releases). + +# Integration test customization + +Configuration of the integration test runtime is done through passing different arguments to the test script. The main useful options are outlined below. + +## Re-using Docker Images + +By default, the test framework will build new Docker images on every test execution. A unique image tag is generated, +and it is written to file at `target/imageTag.txt`. To reuse the images built in a previous run, or to use a Docker image tag +that you have built by other means already, pass the tag to the test script: + + dev/dev-run-integration-tests.sh --image-tag + +where if you still want to use images that were built before by the test framework: + + dev/dev-run-integration-tests.sh --image-tag $(cat target/imageTag.txt) + +## Spark Distribution Under Test + +The Spark code to test is handed to the integration test system via a tarball. Here is the option that is used to specify the tarball: + +* `--spark-tgz ` - set `` to point to a tarball containing the Spark distribution to test. + +TODO: Don't require the packaging of the built Spark artifacts into this tarball, just read them out of the current tree. + +## Customizing the Namespace and Service Account + +* `--namespace ` - set `` to the namespace in which the tests should be run. 
+* `--service-account ` - set `` to the name of the Kubernetes service account to +use in the namespace specified by the `--namespace`. The service account is expected to have permissions to get, list, watch, +and create pods. For clusters with RBAC turned on, it's important that the right permissions are granted to the service account +in the namespace through an appropriate role and role binding. A reference RBAC configuration is provided in `dev/spark-rbac.yaml`. diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh new file mode 100755 index 0000000000000..b28b8b82ca016 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -0,0 +1,106 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +set -xo errexit +TEST_ROOT_DIR=$(git rev-parse --show-toplevel) + +DEPLOY_MODE="minikube" +IMAGE_REPO="docker.io/kubespark" +SPARK_TGZ="N/A" +IMAGE_TAG="N/A" +SPARK_MASTER= +NAMESPACE= +SERVICE_ACCOUNT= +INCLUDE_TAGS="k8s" +EXCLUDE_TAGS= + +# Parse arguments +while (( "$#" )); do + case $1 in + --image-repo) + IMAGE_REPO="$2" + shift + ;; + --image-tag) + IMAGE_TAG="$2" + shift + ;; + --deploy-mode) + DEPLOY_MODE="$2" + shift + ;; + --spark-tgz) + SPARK_TGZ="$2" + shift + ;; + --spark-master) + SPARK_MASTER="$2" + shift + ;; + --namespace) + NAMESPACE="$2" + shift + ;; + --service-account) + SERVICE_ACCOUNT="$2" + shift + ;; + --include-tags) + INCLUDE_TAGS="k8s,$2" + shift + ;; + --exclude-tags) + EXCLUDE_TAGS="$2" + shift + ;; + *) + break + ;; + esac + shift +done + +properties=( + -Dspark.kubernetes.test.sparkTgz=$SPARK_TGZ \ + -Dspark.kubernetes.test.imageTag=$IMAGE_TAG \ + -Dspark.kubernetes.test.imageRepo=$IMAGE_REPO \ + -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE \ + -Dtest.include.tags=$INCLUDE_TAGS +) + +if [ -n $NAMESPACE ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE ) +fi + +if [ -n $SERVICE_ACCOUNT ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.serviceAccountName=$SERVICE_ACCOUNT ) +fi + +if [ -n $SPARK_MASTER ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) +fi + +if [ -n $EXCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) +fi + +$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pkubernetes -Phadoop-2.7 ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml new file mode 100644 index 0000000000000..a4c242f2f2645 --- /dev/null +++ 
b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +apiVersion: v1 +kind: Namespace +metadata: + name: spark +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: spark-sa + namespace: spark +--- +apiVersion: rbac.authorization.k8s.io/v1beta1 +kind: ClusterRole +metadata: + name: spark-role +rules: +- apiGroups: + - "" + resources: + - "pods" + verbs: + - "*" +--- +apiVersion: rbac.authorization.k8s.io/v1beta1 +kind: ClusterRoleBinding +metadata: + name: spark-role-binding +subjects: +- kind: ServiceAccount + name: spark-sa + namespace: spark +roleRef: + kind: ClusterRole + name: spark-role + apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml new file mode 100644 index 0000000000000..614705c1ed668 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -0,0 +1,170 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes-integration-tests_2.11 + + 1.3.0 + 1.4.0 + + 3.0.0 + 3.2.2 + 1.0 + kubernetes-integration-tests + ${project.build.directory}/spark-dist-unpacked + N/A + ${project.build.directory}/imageTag.txt + minikube + docker.io/kubespark + + + + jar + Spark Project Kubernetes Integration Tests + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + io.fabric8 + kubernetes-client + ${kubernetes-client.version} + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + + + + + + + org.codehaus.mojo + exec-maven-plugin + ${exec-maven-plugin.version} + + + setup-integration-test-env + pre-integration-test + + exec + + + scripts/setup-integration-test-env.sh + + --unpacked-spark-tgz + ${spark.kubernetes.test.unpackSparkDir} + + --image-repo + ${spark.kubernetes.test.imageRepo} + + --image-tag + ${spark.kubernetes.test.imageTag} + + --image-tag-output-file + ${spark.kubernetes.test.imageTagFile} + + --deploy-mode + ${spark.kubernetes.test.deployMode} + + --spark-tgz + ${spark.kubernetes.test.sparkTgz} + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + + + org.scalatest + scalatest-maven-plugin + ${scalatest-maven-plugin.version} + + ${project.build.directory}/surefire-reports + . 
+ SparkTestSuite.txt + -ea -Xmx3g -XX:ReservedCodeCacheSize=512m ${extraScalaTestArgs} + + + file:src/test/resources/log4j.properties + true + ${spark.kubernetes.test.imageTagFile} + ${spark.kubernetes.test.unpackSparkDir} + ${spark.kubernetes.test.imageRepo} + ${spark.kubernetes.test.deployMode} + ${spark.kubernetes.test.master} + ${spark.kubernetes.test.namespace} + ${spark.kubernetes.test.serviceAccountName} + + ${test.exclude.tags} + ${test.include.tags} + + + + test + + test + + + + (?<!Suite) + + + + integration-test + integration-test + + test + + + + + + + + + diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh new file mode 100755 index 0000000000000..ccfb8e767c529 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +TEST_ROOT_DIR=$(git rev-parse --show-toplevel) +UNPACKED_SPARK_TGZ="$TEST_ROOT_DIR/target/spark-dist-unpacked" +IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" +DEPLOY_MODE="minikube" +IMAGE_REPO="docker.io/kubespark" +IMAGE_TAG="N/A" +SPARK_TGZ="N/A" + +# Parse arguments +while (( "$#" )); do + case $1 in + --unpacked-spark-tgz) + UNPACKED_SPARK_TGZ="$2" + shift + ;; + --image-repo) + IMAGE_REPO="$2" + shift + ;; + --image-tag) + IMAGE_TAG="$2" + shift + ;; + --image-tag-output-file) + IMAGE_TAG_OUTPUT_FILE="$2" + shift + ;; + --deploy-mode) + DEPLOY_MODE="$2" + shift + ;; + --spark-tgz) + SPARK_TGZ="$2" + shift + ;; + *) + break + ;; + esac + shift +done + +if [[ $SPARK_TGZ == "N/A" ]]; +then + echo "Must specify a Spark tarball to build Docker images against with --spark-tgz." && exit 1; +fi + +rm -rf $UNPACKED_SPARK_TGZ +mkdir -p $UNPACKED_SPARK_TGZ +tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; + +if [[ $IMAGE_TAG == "N/A" ]]; +then + IMAGE_TAG=$(uuidgen); + cd $UNPACKED_SPARK_TGZ + if [[ $DEPLOY_MODE == cloud ]] ; + then + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + if [[ $IMAGE_REPO == gcr.io* ]] ; + then + gcloud docker -- push $IMAGE_REPO/spark:$IMAGE_TAG + else + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push + fi + else + # -m option for minikube. 
+ $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build + fi + cd - +fi + +rm -f $IMAGE_TAG_OUTPUT_FILE +echo -n $IMAGE_TAG > $IMAGE_TAG_OUTPUT_FILE diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..866126bc3c1c2 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/integration-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/integration-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala new file mode 100644 index 0000000000000..4e749c40563dc --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.launcher.SparkLauncher + +private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite => + + import BasicTestsSuite._ + import KubernetesSuite.k8sTestTag + + test("Run SparkPi with no resources", k8sTestTag) { + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with a very long application name.", k8sTestTag) { + sparkAppConf.set("spark.app.name", "long" * 40) + runSparkPiAndVerifyCompletion() + } + + test("Use SparkLauncher.NO_RESOURCE", k8sTestTag) { + sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) + runSparkPiAndVerifyCompletion( + appResource = SparkLauncher.NO_RESOURCE) + } + + test("Run SparkPi with a master URL without a scheme.", k8sTestTag) { + val url = kubernetesTestComponents.kubernetesClient.getMasterUrl + val k8sMasterUrl = if (url.getPort < 0) { + s"k8s://${url.getHost}" + } else { + s"k8s://${url.getHost}:${url.getPort}" + } + sparkAppConf.set("spark.master", k8sMasterUrl) + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with an argument.", k8sTestTag) { + runSparkPiAndVerifyCompletion(appArgs = Array("5")) + } + + test("Run SparkPi with custom labels, annotations, and environment variables.", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.driver.label.label1", "label1-value") + .set("spark.kubernetes.driver.label.label2", "label2-value") + .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") + .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") + .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") + .set("spark.kubernetes.executor.label.label1", "label1-value") + .set("spark.kubernetes.executor.label.label2", "label2-value") + .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") + .set("spark.executorEnv.ENV1", "VALUE1") + .set("spark.executorEnv.ENV2", "VALUE2") + + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkCustomSettings(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkCustomSettings(executorPod) + }) + } + + test("Run extraJVMOptions check on driver", k8sTestTag) { + sparkAppConf + .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") + runSparkJVMCheckAndVerifyCompletion( + expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) + } + + test("Run SparkRemoteFileTest using a remote data file", k8sTestTag) { + sparkAppConf + .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + runSparkRemoteCheckAndVerifyCompletion(appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) + } +} + +private[spark] object BasicTestsSuite { + val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" + val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = + s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" + val REMOTE_PAGE_RANK_DATA_FILE = + "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala 
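BasicTestsSuite above gates every test with k8sTestTag, which is what lets the test.include.tags and test.exclude.tags properties threaded through the integration-test pom and the dev script select or skip these suites. A minimal, hypothetical suite (DemoK8sTag and TaggedDemoSuite are invented names, not part of Spark) showing the ScalaTest tagging mechanism:

    import org.scalatest.{FunSuite, Tag}

    // Stand-in for the suites' "k8s" tag; -Dtest.include.tags=k8s matches on this value.
    object DemoK8sTag extends Tag("k8s")

    class TaggedDemoSuite extends FunSuite {
      // Runs only when the ScalaTest runner is told to include the "k8s" tag,
      // and is skipped when that tag appears in test.exclude.tags.
      test("selected by tag filtering", DemoK8sTag) {
        assert(1 + 1 == 2)
      }
    }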
b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala new file mode 100644 index 0000000000000..c8bd584516ea5 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.scalatest.concurrent.Eventually +import scala.collection.JavaConverters._ + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.{k8sTestTag, INTERVAL, TIMEOUT} + +private[spark] trait ClientModeTestsSuite { k8sSuite: KubernetesSuite => + + test("Run in client mode.", k8sTestTag) { + val labels = Map("spark-app-selector" -> driverPodName) + val driverPort = 7077 + val blockManagerPort = 10000 + val driverService = testBackend + .getKubernetesClient + .services() + .inNamespace(kubernetesTestComponents.namespace) + .createNew() + .withNewMetadata() + .withName(s"$driverPodName-svc") + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(labels.asJava) + .addNewPort() + .withName("driver-port") + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName("block-manager") + .withPort(blockManagerPort) + .withNewTargetPort(blockManagerPort) + .endPort() + .endSpec() + .done() + try { + val driverPod = testBackend + .getKubernetesClient + .pods() + .inNamespace(kubernetesTestComponents.namespace) + .createNew() + .withNewMetadata() + .withName(driverPodName) + .withLabels(labels.asJava) + .endMetadata() + .withNewSpec() + .withServiceAccountName(kubernetesTestComponents.serviceAccountName) + .addNewContainer() + .withName("spark-example") + .withImage(image) + .withImagePullPolicy("IfNotPresent") + .withCommand("/opt/spark/bin/run-example") + .addToArgs("--master", s"k8s://https://kubernetes.default.svc") + .addToArgs("--deploy-mode", "client") + .addToArgs("--conf", s"spark.kubernetes.container.image=$image") + .addToArgs( + "--conf", + s"spark.kubernetes.namespace=${kubernetesTestComponents.namespace}") + .addToArgs("--conf", "spark.kubernetes.authenticate.oauthTokenFile=" + + "/var/run/secrets/kubernetes.io/serviceaccount/token") + .addToArgs("--conf", "spark.kubernetes.authenticate.caCertFile=" + + "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt") + .addToArgs("--conf", s"spark.kubernetes.driver.pod.name=$driverPodName") + .addToArgs("--conf", "spark.executor.memory=500m") + .addToArgs("--conf", "spark.executor.cores=1") + .addToArgs("--conf", "spark.executor.instances=1") + .addToArgs("--conf", + s"spark.driver.host=" + + 
s"${driverService.getMetadata.getName}.${kubernetesTestComponents.namespace}.svc") + .addToArgs("--conf", s"spark.driver.port=$driverPort") + .addToArgs("--conf", s"spark.driver.blockManager.port=$blockManagerPort") + .addToArgs("SparkPi") + .addToArgs("10") + .endContainer() + .endSpec() + .done() + Eventually.eventually(TIMEOUT, INTERVAL) { + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPodName) + .getLog + .contains("Pi is roughly 3"), "The application did not complete.") + } + } finally { + // Have to delete the service manually since it doesn't have an owner reference + kubernetesTestComponents + .kubernetesClient + .services() + .inNamespace(kubernetesTestComponents.namespace) + .delete(driverService) + } + } + +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala new file mode 100644 index 0000000000000..896a83a5badbb --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File +import java.nio.file.{Path, Paths} +import java.util.UUID +import java.util.regex.Pattern + +import com.google.common.io.PatternFilenameFilter +import io.fabric8.kubernetes.api.model.Pod +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} +import org.scalatest.concurrent.{Eventually, PatienceConfiguration} +import org.scalatest.time.{Minutes, Seconds, Span} +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.integrationtest.TestConfig._ +import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} + +private[spark] class KubernetesSuite extends SparkFunSuite + with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite + with PythonTestsSuite with ClientModeTestsSuite { + + import KubernetesSuite._ + + private var sparkHomeDir: Path = _ + private var pyImage: String = _ + private var rImage: String = _ + + protected var image: String = _ + protected var testBackend: IntegrationTestBackend = _ + protected var driverPodName: String = _ + protected var kubernetesTestComponents: KubernetesTestComponents = _ + protected var sparkAppConf: SparkAppConf = _ + protected var containerLocalSparkDistroExamplesJar: String = _ + protected var appLocator: String = _ + + override def beforeAll(): Unit = { + // The scalatest-maven-plugin gives system properties that are referenced but not set null + // values. We need to remove the null-value properties before initializing the test backend. + val nullValueProperties = System.getProperties.asScala + .filter(entry => entry._2.equals("null")) + .map(entry => entry._1.toString) + nullValueProperties.foreach { key => + System.clearProperty(key) + } + + val sparkDirProp = System.getProperty("spark.kubernetes.test.unpackSparkDir") + require(sparkDirProp != null, "Spark home directory must be provided in system properties.") + sparkHomeDir = Paths.get(sparkDirProp) + require(sparkHomeDir.toFile.isDirectory, + s"No directory found for spark home specified at $sparkHomeDir.") + val imageTag = getTestImageTag + val imageRepo = getTestImageRepo + image = s"$imageRepo/spark:$imageTag" + pyImage = s"$imageRepo/spark-py:$imageTag" + rImage = s"$imageRepo/spark-r:$imageTag" + + val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) + .toFile + .listFiles(new PatternFilenameFilter(Pattern.compile("^spark-examples_.*\\.jar$")))(0) + containerLocalSparkDistroExamplesJar = s"local:///opt/spark/examples/jars/" + + s"${sparkDistroExamplesJarFile.getName}" + testBackend = IntegrationTestBackendFactory.getTestBackend + testBackend.initialize() + kubernetesTestComponents = new KubernetesTestComponents(testBackend.getKubernetesClient) + } + + override def afterAll(): Unit = { + testBackend.cleanUp() + } + + before { + appLocator = UUID.randomUUID().toString.replaceAll("-", "") + driverPodName = "spark-test-app-" + UUID.randomUUID().toString.replaceAll("-", "") + sparkAppConf = kubernetesTestComponents.newSparkAppConf() + .set("spark.kubernetes.container.image", image) + .set("spark.kubernetes.driver.pod.name", driverPodName) + .set("spark.kubernetes.driver.label.spark-app-locator", appLocator) + .set("spark.kubernetes.executor.label.spark-app-locator", appLocator) + if (!kubernetesTestComponents.hasUserSpecifiedNamespace) { + kubernetesTestComponents.createNamespace() + } + } + + after { + if 
(!kubernetesTestComponents.hasUserSpecifiedNamespace) { + kubernetesTestComponents.deleteNamespace() + } + deleteDriverPod() + } + + protected def runSparkPiAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, + appArgs: Array[String] = Array.empty[String], + appLocator: String = appLocator, + isJVM: Boolean = true ): Unit = { + runSparkApplicationAndVerifyCompletion( + appResource, + SPARK_PI_MAIN_CLASS, + Seq("Pi is roughly 3"), + appArgs, + driverPodChecker, + executorPodChecker, + appLocator, + isJVM) + } + + protected def runSparkRemoteCheckAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, + appArgs: Array[String], + appLocator: String = appLocator): Unit = { + runSparkApplicationAndVerifyCompletion( + appResource, + SPARK_REMOTE_MAIN_CLASS, + Seq(s"Mounting of ${appArgs.head} was true"), + appArgs, + driverPodChecker, + executorPodChecker, + appLocator, + true) + } + + protected def runSparkJVMCheckAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + mainClass: String = SPARK_DRIVER_MAIN_CLASS, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + appArgs: Array[String] = Array("5"), + expectedJVMValue: Seq[String]): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + true) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + doBasicDriverPodCheck(driverPod) + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedJVMValue.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } + } + + protected def runSparkApplicationAndVerifyCompletion( + appResource: String, + mainClass: String, + expectedLogOnCompletion: Seq[String], + appArgs: Array[String], + driverPodChecker: Pod => Unit, + executorPodChecker: Pod => Unit, + appLocator: String, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + isJVM, + pyFiles) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + driverPodChecker(driverPod) + + val executorPods = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "executor") + .list() + .getItems + executorPods.asScala.foreach { pod => + executorPodChecker(pod) + } + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedLogOnCompletion.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } + } + + protected def 
doBasicDriverPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === image) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + + + protected def doBasicDriverPyPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + + protected def doBasicDriverRPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === rImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + + + protected def doBasicExecutorPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === image) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + + protected def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + + protected def doBasicExecutorRPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === rImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + + protected def checkCustomSettings(pod: Pod): Unit = { + assert(pod.getMetadata.getLabels.get("label1") === "label1-value") + assert(pod.getMetadata.getLabels.get("label2") === "label2-value") + assert(pod.getMetadata.getAnnotations.get("annotation1") === "annotation1-value") + assert(pod.getMetadata.getAnnotations.get("annotation2") === "annotation2-value") + + val container = pod.getSpec.getContainers.get(0) + val envVars = container + .getEnv + .asScala + .map { env => + (env.getName, env.getValue) + } + .toMap + assert(envVars("ENV1") === "VALUE1") + assert(envVars("ENV2") === "VALUE2") + } + + private def deleteDriverPod(): Unit = { + kubernetesTestComponents.kubernetesClient.pods().withName(driverPodName).delete() + Eventually.eventually(TIMEOUT, INTERVAL) { + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPodName) + .get() == null) + } + } +} + +private[spark] object KubernetesSuite { + val k8sTestTag = Tag("k8s") + val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" + val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" + val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" + val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) + val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala new file mode 100644 index 0000000000000..b602fdf39731f --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.nio.file.{Path, Paths} +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.client.DefaultKubernetesClient +import org.scalatest.concurrent.Eventually + +import org.apache.spark.internal.Logging + +private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesClient) { + + val namespaceOption = Option(System.getProperty("spark.kubernetes.test.namespace")) + val hasUserSpecifiedNamespace = namespaceOption.isDefined + val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", "")) + val serviceAccountName = + Option(System.getProperty("spark.kubernetes.test.serviceAccountName")) + .getOrElse("default") + val kubernetesClient = defaultClient.inNamespace(namespace) + val clientConfig = kubernetesClient.getConfiguration + + def createNamespace(): Unit = { + defaultClient.namespaces.createNew() + .withNewMetadata() + .withName(namespace) + .endMetadata() + .done() + } + + def deleteNamespace(): Unit = { + defaultClient.namespaces.withName(namespace).delete() + Eventually.eventually(KubernetesSuite.TIMEOUT, KubernetesSuite.INTERVAL) { + val namespaceList = defaultClient + .namespaces() + .list() + .getItems + .asScala + require(!namespaceList.exists(_.getMetadata.getName == namespace)) + } + } + + def newSparkAppConf(): SparkAppConf = { + new SparkAppConf() + .set("spark.master", s"k8s://${kubernetesClient.getMasterUrl}") + .set("spark.kubernetes.namespace", namespace) + .set("spark.executor.memory", "500m") + .set("spark.executor.cores", "1") + .set("spark.executors.instances", "1") + .set("spark.app.name", "spark-test-app") + .set("spark.ui.enabled", "true") + .set("spark.testing", "false") + .set("spark.kubernetes.submission.waitAppCompletion", "false") + .set("spark.kubernetes.authenticate.driver.serviceAccountName", serviceAccountName) + } +} + +private[spark] class SparkAppConf { + + private val map = mutable.Map[String, String]() + + def set(key: String, value: String): SparkAppConf = { + map.put(key, value) + this + } + + def get(key: String): String = map.getOrElse(key, "") + + def setJars(jars: Seq[String]): Unit = set("spark.jars", jars.mkString(",")) + + override def toString: String = map.toString + + def toStringArray: Iterable[String] = map.toList.flatMap(t => List("--conf", s"${t._1}=${t._2}")) +} + +private[spark] case class SparkAppArguments( + mainAppResource: String, + mainClass: String, + appArgs: Array[String]) + +private[spark] object SparkAppLauncher extends Logging { + def launch( + appArguments: SparkAppArguments, + appConf: SparkAppConf, + timeoutSecs: Int, + sparkHomeDir: Path, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { + val sparkSubmitExecutable = 
sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) + logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") + val preCommandLine = if (isJVM) { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--class", appArguments.mainClass, + "--master", appConf.get("spark.master")) + } else { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--master", appConf.get("spark.master")) + } + val commandLine = + pyFiles.map(s => preCommandLine ++ Array("--py-files", s)).getOrElse(preCommandLine) ++ + appConf.toStringArray :+ appArguments.mainAppResource + + if (appArguments.appArgs.nonEmpty) { + commandLine += appArguments.appArgs.mkString(" ") + } + logInfo(s"Launching a spark app with command line: ${commandLine.mkString(" ")}") + ProcessUtils.executeProcess(commandLine.toArray, timeoutSecs) + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala new file mode 100644 index 0000000000000..d8f3a6cec05c3 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.ArrayBuffer +import scala.io.Source + +import org.apache.spark.internal.Logging + +object ProcessUtils extends Logging { + /** + * executeProcess is used to run a command and return the output if it + * completes within timeout seconds. 
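To make the command assembly above concrete, here is a small sketch (placeholder values, and it assumes code living in the same integrationtest package, since SparkAppConf is private[spark]) of how the conf object defined above is flattened into the arguments spark-submit receives:

    // Illustrative values only; mirrors SparkAppConf.toStringArray behaviour.
    val conf = new SparkAppConf()
      .set("spark.master", "k8s://https://192.168.99.100:8443")  // placeholder master URL
      .set("spark.kubernetes.namespace", "spark-test")           // placeholder namespace
    // Ordering is not guaranteed (backed by a mutable.Map); one possible rendering:
    //   --conf spark.master=k8s://https://192.168.99.100:8443 --conf spark.kubernetes.namespace=spark-test
    println(conf.toStringArray.mkString(" "))

SparkAppLauncher then appends these pairs after the deploy-mode/class/master options and before the main application resource.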
+ */ + def executeProcess(fullCommand: Array[String], timeout: Long): Seq[String] = { + val pb = new ProcessBuilder().command(fullCommand: _*) + pb.redirectErrorStream(true) + val proc = pb.start() + val outputLines = new ArrayBuffer[String] + Utils.tryWithResource(proc.getInputStream)( + Source.fromInputStream(_, "UTF-8").getLines().foreach { line => + logInfo(line) + outputLines += line + }) + assert(proc.waitFor(timeout, TimeUnit.SECONDS), + s"Timed out while executing ${fullCommand.mkString(" ")}") + assert(proc.exitValue == 0, s"Failed to execute ${fullCommand.mkString(" ")}") + outputLines + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala new file mode 100644 index 0000000000000..1ebb30094dcde --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
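A hedged usage sketch for the helper above; the command is an arbitrary example rather than one the tests run, and it assumes kubectl is on the PATH:

    import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils

    // Runs the command with a 30-second timeout; each output line is logged and
    // collected, and the assertions inside executeProcess fail on timeout or a
    // non-zero exit code.
    val lines: Seq[String] = ProcessUtils.executeProcess(
      Array("kubectl", "version", "--client"), timeout = 30)
    lines.foreach(println)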
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} + +private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => + + import PythonTestsSuite._ + import KubernetesSuite.k8sTestTag + + test("Run PySpark on simple pi.py example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_PI, + mainClass = "", + expectedLogOnCompletion = Seq("Pi is roughly 3"), + appArgs = Array("5"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false) + } + + test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonVersion", "2") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonVersion", "3") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python3"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } +} + +private[spark] object PythonTestsSuite { + val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" + val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" + val PYSPARK_FILES: String = CONTAINER_LOCAL_PYSPARK + "pyfiles.py" + val PYSPARK_CONTAINER_TESTS: String = CONTAINER_LOCAL_PYSPARK + "py_container_checks.py" +} + diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala new file mode 100644 index 0000000000000..885a23cfb4864 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
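Adding another Python test follows the same template. The sketch below is hypothetical and would live inside PythonTestsSuite; wordcount.py does ship with the Spark examples, but the expected log token and input path are illustrative:

    // Hypothetical extra test, following the structure of the tests above.
    test("Run PySpark on the wordcount example", k8sTestTag) {
      sparkAppConf
        .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}")
      runSparkApplicationAndVerifyCompletion(
        appResource = CONTAINER_LOCAL_PYSPARK + "wordcount.py",
        mainClass = "",
        expectedLogOnCompletion = Seq("Spark"),          // illustrative token to look for
        appArgs = Array("/opt/spark/RELEASE"),           // illustrative input file in the image
        driverPodChecker = doBasicDriverPyPodCheck,
        executorPodChecker = doBasicExecutorPyPodCheck,
        appLocator = appLocator,
        isJVM = false)
    }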
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} + +private[spark] trait RTestsSuite { k8sSuite: KubernetesSuite => + + import RTestsSuite._ + import KubernetesSuite.k8sTestTag + + test("Run SparkR on simple dataframe.R example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-r:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = SPARK_R_DATAFRAME_TEST, + mainClass = "", + expectedLogOnCompletion = Seq("name: string (nullable = true)", "1 Justin"), + appArgs = Array.empty[String], + driverPodChecker = doBasicDriverRPodCheck, + executorPodChecker = doBasicExecutorRPodCheck, + appLocator = appLocator, + isJVM = false) + } +} + +private[spark] object RTestsSuite { + val CONTAINER_LOCAL_SPARKR: String = "local:///opt/spark/examples/src/main/r/" + val SPARK_R_DATAFRAME_TEST: String = CONTAINER_LOCAL_SPARKR + "dataframe.R" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala new file mode 100644 index 0000000000000..7b05c1355ca24 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
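Note that RTestsSuite, unlike the other test traits, is not mixed into KubernetesSuite above, so the SparkR test is defined but never scheduled. If it is meant to run, the class declaration would need to pull it in, roughly as follows (assuming the spark-r image has been built):

    private[spark] class KubernetesSuite extends SparkFunSuite
      with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite
      with PythonTestsSuite with ClientModeTestsSuite with RTestsSuite {
      // existing body unchanged
    }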
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Pod, Secret, SecretBuilder} +import org.apache.commons.codec.binary.Base64 +import org.apache.commons.io.output.ByteArrayOutputStream +import org.scalatest.concurrent.Eventually + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ + +private[spark] trait SecretsTestsSuite { k8sSuite: KubernetesSuite => + + import SecretsTestsSuite._ + + private def createTestSecret(): Unit = { + val sb = new SecretBuilder() + sb.withNewMetadata() + .withName(ENV_SECRET_NAME) + .endMetadata() + val secUsername = Base64.encodeBase64String(ENV_SECRET_VALUE_1.getBytes()) + val secPassword = Base64.encodeBase64String(ENV_SECRET_VALUE_2.getBytes()) + val envSecretData = Map(ENV_SECRET_KEY_1 -> secUsername, ENV_SECRET_KEY_2 -> secPassword) + sb.addToData(envSecretData.asJava) + val envSecret = sb.build() + val sec = kubernetesTestComponents + .kubernetesClient + .secrets() + .createOrReplace(envSecret) + } + + private def deleteTestSecret(): Unit = { + kubernetesTestComponents + .kubernetesClient + .secrets() + .withName(ENV_SECRET_NAME) + .delete() + } + + test("Run SparkPi with env and mount secrets.", k8sTestTag) { + createTestSecret() + sparkAppConf + .set(s"spark.kubernetes.driver.secrets.$ENV_SECRET_NAME", SECRET_MOUNT_PATH) + .set(s"spark.kubernetes.driver.secretKeyRef.USERNAME", s"$ENV_SECRET_NAME:username") + .set(s"spark.kubernetes.driver.secretKeyRef.PASSWORD", s"$ENV_SECRET_NAME:password") + .set(s"spark.kubernetes.executor.secrets.$ENV_SECRET_NAME", SECRET_MOUNT_PATH) + .set(s"spark.kubernetes.executor.secretKeyRef.USERNAME", s"$ENV_SECRET_NAME:username") + .set(s"spark.kubernetes.executor.secretKeyRef.PASSWORD", s"$ENV_SECRET_NAME:password") + try { + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkSecrets(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkSecrets(executorPod) + }, + appArgs = Array("1000") // give it enough time for all execs to be visible + ) + } finally { + // make sure this always run + deleteTestSecret() + } + } + + private def checkSecrets(pod: Pod): Unit = { + Eventually.eventually(TIMEOUT, INTERVAL) { + implicit val podName: String = pod.getMetadata.getName + val env = executeCommand("env") + assert(env.toString.contains(ENV_SECRET_VALUE_1)) + assert(env.toString.contains(ENV_SECRET_VALUE_2)) + val fileUsernameContents = executeCommand("cat", s"$SECRET_MOUNT_PATH/$ENV_SECRET_KEY_1") + val filePasswordContents = executeCommand("cat", s"$SECRET_MOUNT_PATH/$ENV_SECRET_KEY_2") + assert(fileUsernameContents.toString.trim.equals(ENV_SECRET_VALUE_1)) + assert(filePasswordContents.toString.trim.equals(ENV_SECRET_VALUE_2)) + } + } + + private def executeCommand(cmd: String*)(implicit podName: String): String = { + val out = new ByteArrayOutputStream() + val watch = kubernetesTestComponents + .kubernetesClient + .pods() + .withName(podName) + .readingInput(System.in) + .writingOutput(out) + .writingError(System.err) + .withTTY() + .exec(cmd.toArray: _*) + // wait to get some result back + Thread.sleep(1000) + watch.close() + out.flush() + out.toString() + } +} + +private[spark] object SecretsTestsSuite { + val ENV_SECRET_NAME = "mysecret" + val SECRET_MOUNT_PATH = "/etc/secret" + val ENV_SECRET_KEY_1 = "username" + val ENV_SECRET_KEY_2 = "password" + val ENV_SECRET_VALUE_1 = "secretusername" + val 
ENV_SECRET_VALUE_2 = "secretpassword" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala new file mode 100644 index 0000000000000..f1fd6dc19ce54 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.util.concurrent.TimeUnit + +import com.google.common.util.concurrent.SettableFuture +import io.fabric8.kubernetes.api.model.HasMetadata +import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.internal.readiness.Readiness + +private[spark] class SparkReadinessWatcher[T <: HasMetadata] extends Watcher[T] { + + private val signal = SettableFuture.create[Boolean] + + override def eventReceived(action: Action, resource: T): Unit = { + if ((action == Action.MODIFIED || action == Action.ADDED) && + Readiness.isReady(resource)) { + signal.set(true) + } + } + + override def onClose(cause: KubernetesClientException): Unit = {} + + def waitUntilReady(): Boolean = signal.get(60, TimeUnit.SECONDS) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala new file mode 100644 index 0000000000000..5a49e0779160c --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
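SparkReadinessWatcher has no caller in the files shown here; a hedged sketch of how it could be used (assuming the fabric8 watch API, where watch() returns a Closeable handle, and the driver pod name set in before()) is to register it as a watch and block until the pod reports Ready:

    import io.fabric8.kubernetes.api.model.Pod

    // Blocks until the driver pod is Ready, or throws after the watcher's
    // 60-second internal timeout; the watch handle is closed via tryWithResource.
    val watcher = new SparkReadinessWatcher[Pod]
    Utils.tryWithResource(
        kubernetesTestComponents.kubernetesClient
          .pods()
          .withName(driverPodName)
          .watch(watcher)) { _ =>
      watcher.waitUntilReady()
    }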
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files + +object TestConfig { + def getTestImageTag: String = { + val imageTagFileProp = System.getProperty("spark.kubernetes.test.imageTagFile") + require(imageTagFileProp != null, "Image tag file must be provided in system properties.") + val imageTagFile = new File(imageTagFileProp) + require(imageTagFile.isFile, s"No file found for image tag at ${imageTagFile.getAbsolutePath}.") + Files.toString(imageTagFile, Charsets.UTF_8).trim + } + + def getTestImageRepo: String = { + val imageRepo = System.getProperty("spark.kubernetes.test.imageRepo") + require(imageRepo != null, "Image repo must be provided in system properties.") + imageRepo + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala new file mode 100644 index 0000000000000..8595d0eab1126 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +object TestConstants { + val MINIKUBE_TEST_BACKEND = "minikube" + val GCE_TEST_BACKEND = "gce" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala new file mode 100644 index 0000000000000..663f8b6523ac8 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
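For reference, these are the system properties read by TestConfig and the harness classes above, set here programmatically with placeholder values (real runs pass them as -D options to the test JVM):

    // Placeholder values; the property names are the ones the code above reads.
    System.setProperty("spark.kubernetes.test.imageRepo", "docker.io/my-repo")
    System.setProperty("spark.kubernetes.test.imageTagFile", "/tmp/spark-image-tag.txt")
    System.setProperty("spark.kubernetes.test.unpackSparkDir", "/tmp/spark-dist")
    System.setProperty("spark.kubernetes.test.deployMode", "minikube")
    System.setProperty("spark.kubernetes.test.namespace", "spark-it")          // optional
    System.setProperty("spark.kubernetes.test.serviceAccountName", "default")  // optional

    assert(TestConfig.getTestImageRepo == "docker.io/my-repo")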
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.Closeable +import java.net.URI + +import org.apache.spark.internal.Logging + +object Utils extends Logging { + + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { + val resource = createResource + try f.apply(resource) finally resource.close() + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala new file mode 100644 index 0000000000000..284712c6d250e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s.integrationtest.backend + +import io.fabric8.kubernetes.client.DefaultKubernetesClient + +import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.MinikubeTestBackend + +private[spark] trait IntegrationTestBackend { + def initialize(): Unit + def getKubernetesClient: DefaultKubernetesClient + def cleanUp(): Unit = {} +} + +private[spark] object IntegrationTestBackendFactory { + val deployModeConfigKey = "spark.kubernetes.test.deployMode" + + def getTestBackend: IntegrationTestBackend = { + val deployMode = Option(System.getProperty(deployModeConfigKey)) + .getOrElse("minikube") + if (deployMode == "minikube") { + MinikubeTestBackend + } else { + throw new IllegalArgumentException( + "Invalid " + deployModeConfigKey + ": " + deployMode) + } + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala new file mode 100644 index 0000000000000..6494cbc18f33e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
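Utils.tryWithResource is the loan pattern ProcessUtils relies on above; a self-contained sketch with an arbitrary Closeable and an illustrative file path:

    import java.io.{BufferedReader, FileReader}

    // Reads one line and guarantees the reader is closed even if readLine() throws.
    val firstLine: String =
      Utils.tryWithResource(new BufferedReader(new FileReader("/etc/hostname"))) { reader =>
        reader.readLine()
      }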
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.minikube + +import java.io.File +import java.nio.file.Paths + +import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient} + +import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils +import org.apache.spark.internal.Logging + +// TODO support windows +private[spark] object Minikube extends Logging { + + private val MINIKUBE_STARTUP_TIMEOUT_SECONDS = 60 + + def getMinikubeIp: String = { + val outputs = executeMinikube("ip") + .filter(_.matches("^\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$")) + assert(outputs.size == 1, "Unexpected amount of output from minikube ip") + outputs.head + } + + def getMinikubeStatus: MinikubeStatus.Value = { + val statusString = executeMinikube("status") + .filter(line => line.contains("minikubeVM: ") || line.contains("minikube:")) + .head + .replaceFirst("minikubeVM: ", "") + .replaceFirst("minikube: ", "") + MinikubeStatus.unapply(statusString) + .getOrElse(throw new IllegalStateException(s"Unknown status $statusString")) + } + + def getKubernetesClient: DefaultKubernetesClient = { + val kubernetesMaster = s"https://${getMinikubeIp}:8443" + val userHome = System.getProperty("user.home") + val kubernetesConf = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(kubernetesMaster) + .withCaCertFile(Paths.get(userHome, ".minikube", "ca.crt").toFile.getAbsolutePath) + .withClientCertFile(Paths.get(userHome, ".minikube", "apiserver.crt").toFile.getAbsolutePath) + .withClientKeyFile(Paths.get(userHome, ".minikube", "apiserver.key").toFile.getAbsolutePath) + .build() + new DefaultKubernetesClient(kubernetesConf) + } + + private def executeMinikube(action: String, args: String*): Seq[String] = { + ProcessUtils.executeProcess( + Array("bash", "-c", s"minikube $action") ++ args, MINIKUBE_STARTUP_TIMEOUT_SECONDS) + } +} + +private[spark] object MinikubeStatus extends Enumeration { + + // The following states are listed according to + // https://github.com/docker/machine/blob/master/libmachine/state/state.go. 
+ val STARTING = status("Starting") + val RUNNING = status("Running") + val PAUSED = status("Paused") + val STOPPING = status("Stopping") + val STOPPED = status("Stopped") + val ERROR = status("Error") + val TIMEOUT = status("Timeout") + val SAVED = status("Saved") + val NONE = status("") + + def status(value: String): Value = new Val(nextId, value) + def unapply(s: String): Option[Value] = values.find(s == _.toString) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala new file mode 100644 index 0000000000000..cb9324179d70e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.minikube + +import io.fabric8.kubernetes.client.DefaultKubernetesClient + +import org.apache.spark.deploy.k8s.integrationtest.backend.IntegrationTestBackend + +private[spark] object MinikubeTestBackend extends IntegrationTestBackend { + + private var defaultClient: DefaultKubernetesClient = _ + + override def initialize(): Unit = { + val minikubeStatus = Minikube.getMinikubeStatus + require(minikubeStatus == MinikubeStatus.RUNNING, + s"Minikube must be running to use the Minikube backend for integration tests." + + s" Current status is: $minikubeStatus.") + defaultClient = Minikube.getKubernetesClient + } + + override def cleanUp(): Unit = { + super.cleanUp() + } + + override def getKubernetesClient: DefaultKubernetesClient = { + defaultClient + } +} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index ccf33e8d4283c..64698b55c6bb6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -51,6 +51,14 @@ private[mesos] class MesosClusterDispatcher( conf: SparkConf) extends Logging { + { + // This doesn't support authentication because the RestSubmissionServer doesn't support it. + val authKey = SecurityManager.SPARK_AUTH_SECRET_CONF + require(conf.getOption(authKey).isEmpty, + s"The MesosClusterDispatcher does not support authentication via ${authKey}. 
It is not " + + s"currently possible to run jobs in cluster mode with authentication on.") + } + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) private val recoveryMode = conf.get(RECOVERY_MODE).toUpperCase() logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index 022191d0070fd..91f64141e5318 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -39,7 +39,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")

      Cannot find driver {driverId}

      - return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + return UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } val driverState = state.get val driverHeaders = Seq("Driver property", "Value") @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

      Driver state information for driver id {driverId}

      - Back to Drivers + Back to Drivers

      Driver state: {driverState.state}

      @@ -87,7 +87,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
      ; - UIUtils.basicSparkPage(content, s"Details for Job $driverId") + UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 88a6614d51384..c53285331ea68 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -62,7 +62,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {retryTable}
      ; - UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + UIUtils.basicSparkPage(request, content, "Spark Drivers for Mesos cluster") } private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 604978967d6db..15bbe60d6c8fb 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -40,7 +40,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) attachPage(new DriverPage(this)) - attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR) } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index d224a7325820a..7d80eedcc43ce 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,8 +30,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} -import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.deploy.mesos.config +import org.apache.spark.deploy.mesos.{config, MesosDriverDescription} import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils @@ -418,6 +417,18 @@ private[spark] class MesosClusterScheduler( envBuilder.build() } + private def isContainerLocalAppJar(desc: MesosDriverDescription): Boolean = { + val isLocalJar = desc.jarUrl.startsWith("local://") + val isContainerLocal = desc.conf.getOption("spark.mesos.appJar.local.resolution.mode").exists { + case "container" => true + case "host" => false + case other => + logWarning(s"Unknown spark.mesos.appJar.local.resolution.mode $other, using host.") + false + } + isLocalJar && isContainerLocal + } + private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = { val confUris = List(conf.getOption("spark.mesos.uris"), desc.conf.getOption("spark.mesos.uris"), @@ -425,10 +436,14 @@ private[spark] class MesosClusterScheduler( _.map(_.split(",").map(_.trim)) ).flatten - val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") - - ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => - CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + if (isContainerLocalAppJar(desc)) { + (confUris ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } else { + val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") + ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } } private def getContainerInfo(desc: 
MesosDriverDescription): ContainerInfo.Builder = { @@ -480,7 +495,14 @@ private[spark] class MesosClusterScheduler( (cmdExecutable, ".") } val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") - val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val primaryResource = { + if (isContainerLocalAppJar(desc)) { + new File(desc.jarUrl.stripPrefix("local://")).toString() + } else { + new File(sandboxPath, desc.jarUrl.split("/").last).toString() + } + } + val appArguments = desc.command.arguments.mkString(" ") s"$executable $cmdOptions $primaryResource $appArguments" @@ -530,9 +552,9 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } + options ++= Seq("--conf", s"${key}=${value}") } - options + options.map(shellEscape) } /** diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 53f5f61cca486..1ce2f816dffb2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -227,7 +227,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") + val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, taskId) + }.getOrElse("") // Set the environment variable through a command prefix // to append to the existing value of the variable @@ -632,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slave.hostname, externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"), + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index d6d939d246109..0bb6fe0fa4bdf 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -111,7 +111,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, execId) + }.getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => 
Utils.libraryPathEnvPrefix(Seq(p)) @@ -451,4 +453,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( super.applicationId } + override def maxNumConcurrentTasks(): Int = { + // TODO SPARK-25074 support this method for MesosFineGrainedSchedulerBackend + 0 + } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index bfb73611f0530..b4364a5e2eb3a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -117,7 +117,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { case Array(key, value) => Some(param.setKey(key).setValue(value)) case spec => - logWarning(s"Unable to parse arbitary parameters: $params. " + logWarning(s"Unable to parse arbitrary parameters: $params. " + "Expected form: \"key=value(, ...)\"") None } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index ecbcc960fc5a0..8ef1e18f83de3 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -355,7 +355,7 @@ trait MesosSchedulerUtils extends Logging { * https://github.com/apache/mesos/blob/master/src/common/values.cpp * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp * - * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * @param constraintsVal constains string consisting of ';' separated key-value pairs (separated * by ':') * @return Map of constraints to match resources offers. 
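As an illustration of the constraint string this method parses (the attribute names and values are made up), the configuration side looks like:

    import org.apache.spark.SparkConf

    // Offers must carry os=centos7 and a zone attribute equal to either listed value.
    val conf = new SparkConf()
      .set("spark.mesos.constraints", "os:centos7;zone:us-east-1a,us-east-1b")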
*/ diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala index 33e7d69d53d38..057c51db455ef 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.deploy.TestPrematureExit class MesosClusterDispatcherArgumentsSuite extends SparkFunSuite with TestPrematureExit { - test("test if spark config args are passed sucessfully") { + test("test if spark config args are passed successfully") { val args = Array[String]("--master", "mesos://localhost:5050", "--conf", "key1=value1", "--conf", "spark.mesos.key2=value2", "--verbose") val conf = new SparkConf() diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f4bd1ee9da6f7..b790c7cd27794 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -789,6 +789,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.nodeBlacklist).thenReturn(Set[String]()) when(taskScheduler.sc).thenReturn(sc) externalShuffleClient = mock[MesosExternalShuffleClient] diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 2d2f90c63a309..31f84310485a0 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -253,6 +253,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), @@ -361,6 +362,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index d04989e138f83..8f94e3f731007 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -18,8 +18,8 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} -import java.lang.reflect.InvocationTargetException -import java.net.{Socket, URI, 
URL} +import java.lang.reflect.{InvocationTargetException, Modifier} +import java.net.{URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -28,6 +28,7 @@ import scala.concurrent.Promise import scala.concurrent.duration.Duration import scala.util.control.NonFatal +import org.apache.commons.lang3.{StringUtils => ComStrUtils} import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ @@ -43,6 +44,7 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.AMCredentialRenewer import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -67,6 +69,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends private val securityMgr = new SecurityManager(sparkConf) + private var metricsSystem: Option[MetricsSystem] = None + // Set system properties for each config entry. This covers two use cases: // - The default configuration stored by the SparkHadoopUtil class // - The user application creating a new SparkConf in cluster mode @@ -308,7 +312,17 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e) + "Uncaught exception: " + StringUtils.stringifyException(e)) + } finally { + try { + metricsSystem.foreach { ms => + ms.report() + ms.stop() + } + } catch { + case e: Exception => + logWarning("Exception during stopping of the metric system: ", e) + } } } @@ -346,7 +360,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends synchronized { if (!finished) { val inShutdown = ShutdownHookManager.inShutdown() - if (registered) { + if (registered || !isClusterMode) { exitCode = code finalStatus = status } else { @@ -355,7 +369,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } logInfo(s"Final app status: $finalStatus, exitCode: $exitCode" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - finalMsg = msg + finalMsg = ComStrUtils.abbreviate(msg, sparkConf.get(AM_FINAL_MSG_LIMIT).toInt) finished = true if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { logDebug("shutting down reporter thread") @@ -389,37 +403,40 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def registerAM( + host: String, + port: Int, _sparkConf: SparkConf, - _rpcEnv: RpcEnv, - driverRef: RpcEndpointRef, - uiAddress: Option[String]) = { + uiAddress: Option[String]): Unit = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() val historyAddress = ApplicationMaster .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) - val driverUrl = RpcEndpointAddress( - _sparkConf.get("spark.driver.host"), - _sparkConf.get("spark.driver.port").toInt, + client.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress) + registered = true + } + + private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf): Unit = { + val appId = 
client.getAttemptId().getApplicationId().toString() + val driverUrl = RpcEndpointAddress(driverRef.address.host, driverRef.address.port, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString // Before we initialize the allocator, let's log the information about how executors will // be run up front, to avoid printing this out for every single executor being launched. // Use placeholders for information that changes such as executor IDs. logInfo { - val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt - val executorCores = sparkConf.get(EXECUTOR_CORES) - val dummyRunner = new ExecutorRunnable(None, yarnConf, sparkConf, driverUrl, "", + val executorMemory = _sparkConf.get(EXECUTOR_MEMORY).toInt + val executorCores = _sparkConf.get(EXECUTOR_CORES) + val dummyRunner = new ExecutorRunnable(None, yarnConf, _sparkConf, driverUrl, "", "", executorMemory, executorCores, appId, securityMgr, localResources) dummyRunner.launchContextDebugInfo() } - allocator = client.register(driverUrl, - driverRef, + allocator = client.createAllocator( yarnConf, _sparkConf, - uiAddress, - historyAddress, + driverUrl, + driverRef, securityMgr, localResources) @@ -431,18 +448,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverRef)) allocator.allocateResources() + val ms = MetricsSystem.createMetricsSystem("applicationMaster", sparkConf, securityMgr) + val prefix = _sparkConf.get(YARN_METRICS_NAMESPACE).getOrElse(appId) + ms.registerSource(new ApplicationMasterSource(prefix, allocator)) + ms.start() + metricsSystem = Some(ms) reporterThread = launchReporterThread() } - /** - * @return An [[RpcEndpoint]] that communicates with the driver's scheduler backend. - */ - private def createSchedulerRef(host: String, port: String): RpcEndpointRef = { - rpcEnv.setupEndpointRef( - RpcAddress(host, port.toInt), - YarnSchedulerBackend.ENDPOINT_NAME) - } - private def runDriver(): Unit = { addAmIpFilter(None) userClassThread = startUserApplication() @@ -456,11 +469,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Duration(totalWaitTime, TimeUnit.MILLISECONDS)) if (sc != null) { rpcEnv = sc.env.rpcEnv - val driverRef = createSchedulerRef( - sc.getConf.get("spark.driver.host"), - sc.getConf.get("spark.driver.port")) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl)) - registered = true + + val userConf = sc.getConf + val host = userConf.get("spark.driver.host") + val port = userConf.get("spark.driver.port").toInt + registerAM(host, port, userConf, sc.ui.map(_.webUrl)) + + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(host, port), + YarnSchedulerBackend.ENDPOINT_NAME) + createAllocator(driverRef, userConf) } else { // Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. @@ -486,10 +504,18 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val amCores = sparkConf.get(AM_CORES) rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr, amCores, true) - val driverRef = waitForSparkDriver() + + // The client-mode AM doesn't listen for incoming connections, so report an invalid port. + registerAM(hostname, -1, sparkConf, sparkConf.getOption("spark.driver.appUIAddress")) + + // The driver should be up and listening, so unlike cluster mode, just try to connect to it + // with no waiting or retrying. 
+ val (driverHost, driverPort) = Utils.parseHostPort(args.userArgs(0)) + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(driverHost, driverPort), + YarnSchedulerBackend.ENDPOINT_NAME) addAmIpFilter(Some(driverRef)) - registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress")) - registered = true + createAllocator(driverRef, sparkConf) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -508,6 +534,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, s"Max number of executor failures ($maxNumExecutorFailures) reached") + } else if (allocator.isAllNodeBlacklisted) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Due to executor failures all available nodes are blacklisted") } else { logDebug("Sending progress") allocator.allocateResources() @@ -600,40 +630,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private def waitForSparkDriver(): RpcEndpointRef = { - logInfo("Waiting for Spark driver to be reachable.") - var driverUp = false - val hostport = args.userArgs(0) - val (driverHost, driverPort) = Utils.parseHostPort(hostport) - - // Spark driver should already be up since it launched us, but we don't want to - // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME) - val deadline = System.currentTimeMillis + totalWaitTimeMs - - while (!driverUp && !finished && System.currentTimeMillis < deadline) { - try { - val socket = new Socket(driverHost, driverPort) - socket.close() - logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) - driverUp = true - } catch { - case e: Exception => - logError("Failed to connect to driver at %s:%s, retrying ...". - format(driverHost, driverPort)) - Thread.sleep(100L) - } - } - - if (!driverUp) { - throw new SparkException("Failed to connect to driver!") - } - - sparkConf.set("spark.driver.host", driverHost) - sparkConf.set("spark.driver.port", driverPort.toString) - createSchedulerRef(driverHost, driverPort.toString) - } - /** Add the Yarn IP filter that is required for properly securing the UI. 
*/ private def addAmIpFilter(driver: Option[RpcEndpointRef]) = { val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) @@ -675,9 +671,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val userThread = new Thread { override def run() { try { - mainMethod.invoke(null, userArgs.toArray) - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - logDebug("Done running users class") + if (!Modifier.isStatic(mainMethod.getModifiers)) { + logError(s"Could not find static main method in object ${args.userClass}") + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS) + } else { + mainMethod.invoke(null, userArgs.toArray) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running user class") + } } catch { case e: InvocationTargetException => e.getCause match { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala new file mode 100644 index 0000000000000..0fec916582602 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[spark] class ApplicationMasterSource(prefix: String, yarnAllocator: YarnAllocator) + extends Source { + + override val sourceName: String = prefix + ".applicationMaster" + override val metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("numExecutorsFailed"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumExecutorsFailed + }) + + metricRegistry.register(MetricRegistry.name("numExecutorsRunning"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumExecutorsRunning + }) + + metricRegistry.register(MetricRegistry.name("numReleasedContainers"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumReleasedContainers + }) + + metricRegistry.register(MetricRegistry.name("numLocalityAwareTasks"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.numLocalityAwareTasks + }) + + metricRegistry.register(MetricRegistry.name("numContainersPendingAllocate"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.numContainersPendingAllocate + }) + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5763c3dbc5a8a..4a85898ef880b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -91,6 +91,13 @@ private[spark] class Client( private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt + private val isPython = sparkConf.get(IS_PYTHON_APP) + private val pysparkWorkerMemory: Int = if (isPython) { + sparkConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + } else { + 0 + } + private val distCacheMgr = new ClientDistributedCacheManager() private val principal = sparkConf.get(PRINCIPAL).orNull @@ -333,18 +340,19 @@ private[spark] class Client( val maxMem = newAppResponse.getMaximumResourceCapability().getMemory() logInfo("Verifying our application has not requested more than the maximum " + s"memory capability of the cluster ($maxMem MB per container)") - val executorMem = executorMemory + executorMemoryOverhead + val executorMem = executorMemory + executorMemoryOverhead + pysparkWorkerMemory if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory ($executorMemory" + - s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + - "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + - "'yarn.nodemanager.resource.memory-mb'.") + throw new IllegalArgumentException(s"Required executor memory ($executorMemory), overhead " + + s"($executorMemoryOverhead MB), and PySpark memory ($pysparkWorkerMemory MB) is above " + + s"the max threshold ($maxMem MB) of this cluster! Please check the values of " + + s"'yarn.scheduler.maximum-allocation-mb' and/or 'yarn.nodemanager.resource.memory-mb'.") } val amMem = amMemory + amMemoryOverhead if (amMem > maxMem) { throw new IllegalArgumentException(s"Required AM memory ($amMemory" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! 
" + - "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") + "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + + "'yarn.nodemanager.resource.memory-mb'.") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( amMem, @@ -437,7 +445,7 @@ private[spark] class Client( } } - /** + /* * Distribute a file to the cluster. * * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied @@ -811,10 +819,12 @@ private[spark] class Client( // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. if (pythonPath.nonEmpty) { - val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + val pythonPathList = (sys.env.get("PYTHONPATH") ++ pythonPath) + env("PYTHONPATH") = (env.get("PYTHONPATH") ++ pythonPathList) .mkString(ApplicationConstants.CLASS_PATH_SEPARATOR) - env("PYTHONPATH") = pythonPathStr - sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) + val pythonPathExecutorEnv = (sparkConf.getExecutorEnv.toMap.get("PYTHONPATH") ++ + pythonPathList).mkString(ApplicationConstants.CLASS_PATH_SEPARATOR) + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathExecutorEnv) } if (isClusterMode) { @@ -892,12 +902,15 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) + prefixEnv = Some(createLibraryPathPrefix(libraryPaths.mkString(File.pathSeparator), + sparkConf)) } if (sparkConf.get(AM_JAVA_OPTIONS).isDefined) { logWarning(s"${AM_JAVA_OPTIONS.key} will not take effect in cluster mode") @@ -914,10 +927,12 @@ private[spark] class Client( s"(was '$opts'). Use spark.yarn.am.memory instead." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) + prefixEnv = Some(createLibraryPathPrefix(paths, sparkConf)) } } @@ -1015,8 +1030,7 @@ private[spark] class Client( appId: ApplicationId, returnOnRunning: Boolean = false, logApplicationReport: Boolean = true, - interval: Long = sparkConf.get(REPORT_INTERVAL)): - (YarnApplicationState, FinalApplicationStatus) = { + interval: Long = sparkConf.get(REPORT_INTERVAL)): YarnAppReport = { var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) @@ -1027,11 +1041,13 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") cleanupStagingDir(appId) - return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + return YarnAppReport(YarnApplicationState.KILLED, FinalApplicationStatus.KILLED, None) case NonFatal(e) => - logError(s"Failed to contact YARN for application $appId.", e) + val msg = s"Failed to contact YARN for application $appId." 
+ logError(msg, e) // Don't necessarily clean up staging dir because status is unknown - return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) + return YarnAppReport(YarnApplicationState.FAILED, FinalApplicationStatus.FAILED, + Some(msg)) } val state = report.getYarnApplicationState @@ -1069,14 +1085,14 @@ private[spark] class Client( } if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { cleanupStagingDir(appId) - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } if (returnOnRunning && state == YarnApplicationState.RUNNING) { - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } lastState = state @@ -1125,16 +1141,17 @@ private[spark] class Client( throw new SparkException(s"Application $appId finished with status: $state") } } else { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { + val YarnAppReport(appState, finalState, diags) = monitorApplication(appId) + if (appState == YarnApplicationState.FAILED || finalState == FinalApplicationStatus.FAILED) { + diags.foreach { err => + logError(s"Application diagnostics message: $err") + } throw new SparkException(s"Application $appId finished with failed status") } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { + if (appState == YarnApplicationState.KILLED || finalState == FinalApplicationStatus.KILLED) { throw new SparkException(s"Application $appId is killed") } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + if (finalState == FinalApplicationStatus.UNDEFINED) { throw new SparkException(s"The final status of application $appId is undefined") } } @@ -1148,7 +1165,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) @@ -1473,6 +1490,29 @@ private object Client extends Logging { uri.startsWith(s"$LOCAL_SCHEME:") } + def createAppReport(report: ApplicationReport): YarnAppReport = { + val diags = report.getDiagnostics() + val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None + YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt) + } + + /** + * Create a properly quoted and escaped library path string to be added as a prefix to the command + * executed by YARN. This is different from normal quoting / escaping due to YARN executing the + * command through "bash -c". + */ + def createLibraryPathPrefix(libpath: String, conf: SparkConf): String = { + val cmdPrefix = if (Utils.isWindows) { + Utils.libraryPathEnvPrefix(Seq(libpath)) + } else { + val envName = Utils.libraryPathEnvName + // For quotes, escape both the quote and the escape character when encoding in the command + // string. 
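A standalone sketch of what the quoting in the next couple of lines produces for a concrete path, assuming a Linux node (so Utils.libraryPathEnvName resolves to LD_LIBRARY_PATH); the sample path is borrowed from the SPARK-24446 test configuration added further down:

    object LibraryPathPrefixSketch {
      def main(args: Array[String]): Unit = {
        val envName = "LD_LIBRARY_PATH"
        val libpath = """/tmp/does not exist:/tmp/quote":/tmp/ampersand&"""
        // Escape both the quote and the escape character for each embedded quote.
        val quoted = libpath.replace("\"", "\\\\\\\"")
        val prefix = envName + "=\\\"" + quoted + java.io.File.pathSeparator + "$" + envName + "\\\""
        println(prefix)
        // LD_LIBRARY_PATH=\"/tmp/does not exist:/tmp/quote\\\":/tmp/ampersand&:$LD_LIBRARY_PATH\"
      }
    }

The extra level of escaping is needed because YARN runs the container launch command through "bash -c", so one layer of quoting is consumed before the environment assignment is evaluated.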
+ val quoted = libpath.replace("\"", "\\\\\\\"") + envName + "=\\\"" + quoted + File.pathSeparator + "$" + envName + "\\\"" + } + getClusterPath(conf, cmdPrefix) + } } private[spark] class YarnClusterApplication extends SparkApplication { @@ -1487,3 +1527,8 @@ private[spark] class YarnClusterApplication extends SparkApplication { } } + +private[spark] case class YarnAppReport( + appState: YarnApplicationState, + finalState: FinalApplicationStatus, + diagnostics: Option[String]) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ab08698035c98..49a0b93aa5c40 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -131,20 +131,20 @@ private[yarn] class ExecutorRunnable( // Extra options for the JVM val javaOpts = ListBuffer[String]() - // Set the environment variable through a command prefix - // to append to the existing value of the variable - var prefixEnv: Option[String] = None - // Set the JVM memory val executorMemoryString = executorMemory + "m" javaOpts += "-Xmx" + executorMemoryString // Set extra Java options for the executor, if defined sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) + javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } - sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => - prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) + + // Set the library path through a command prefix to append to the existing value of the + // env variable. 
+ val prefixEnv = sparkConf.get(EXECUTOR_LIBRARY_PATH).map { libPath => + Client.createLibraryPathPrefix(libPath, sparkConf) } javaOpts += "-Djava.io.tmpdir=" + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebee3d431744d..8a7551de7c088 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -24,7 +24,7 @@ import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records._ @@ -66,7 +66,8 @@ private[yarn] class YarnAllocator( appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, localResources: Map[String, LocalResource], - resolver: SparkRackResolver) + resolver: SparkRackResolver, + clock: Clock = new SystemClock) extends Logging { import YarnAllocator._ @@ -102,18 +103,14 @@ private[yarn] class YarnAllocator( private var executorIdCounter: Int = driverRef.askSync[Int](RetrieveLastAllocatedExecutorId) - // Queue to store the timestamp of failed executors - private val failedExecutorsTimeStamps = new Queue[Long]() + private[spark] val failureTracker = new FailureTracker(sparkConf, clock) - private var clock: Clock = new SystemClock - - private val executorFailuresValidityInterval = - sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + private val allocatorBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClient, failureTracker) @volatile private var targetNumExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) - private var currentNodeBlacklist = Set.empty[String] // Executor loss reason requests that are pending - maps from executor ID for inquiry to a // list of requesters that should be responded to once we find out why the given executor @@ -136,10 +133,17 @@ private[yarn] class YarnAllocator( // Additional memory overhead. protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt + protected val pysparkWorkerMemory: Int = if (sparkConf.get(IS_PYTHON_APP)) { + sparkConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + } else { + 0 + } // Number of cores per executor. protected val executorCores = sparkConf.get(EXECUTOR_CORES) // Resource capability requested for each executors - private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + private[yarn] val resource = Resource.newInstance( + executorMemory + memoryOverhead + pysparkWorkerMemory, + executorCores) private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( "ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS)) @@ -149,43 +153,33 @@ private[yarn] class YarnAllocator( private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION) - // A map to store preferred hostname and possible task numbers running on it. 
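With the PySpark worker memory now folded in, the container request above becomes executorMemory + memoryOverhead + pysparkWorkerMemory. A back-of-the-envelope sketch with assumed values (4 GiB executor, the usual 10% overhead with a 384 MB floor, 512 MB for PYSPARK_EXECUTOR_MEMORY); the numbers are illustrative, not taken from this patch:

    object ContainerSizeSketch {
      def main(args: Array[String]): Unit = {
        val executorMemoryMb = 4096                                     // spark.executor.memory
        val overheadMb = math.max((0.10 * executorMemoryMb).toInt, 384) // assumed overhead factor and floor
        val pysparkWorkerMb = 512                                       // PYSPARK_EXECUTOR_MEMORY, 0 when unset
        val containerMb = executorMemoryMb + overheadMb + pysparkWorkerMb
        println(s"Requesting $containerMb MB per YARN container")       // 4096 + 409 + 512 = 5017
      }
    }

The same three-way sum is what Client.verifyClusterResources now validates against yarn.scheduler.maximum-allocation-mb.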
private var hostToLocalTaskCounts: Map[String, Int] = Map.empty // Number of tasks that have locality preferences in active stages - private var numLocalityAwareTasks: Int = 0 + private[yarn] var numLocalityAwareTasks: Int = 0 // A container placement strategy based on pending tasks' locality preference private[yarn] val containerPlacementStrategy = new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) - /** - * Use a different clock for YarnAllocator. This is mainly used for testing. - */ - def setClock(newClock: Clock): Unit = { - clock = newClock - } - def getNumExecutorsRunning: Int = runningExecutors.size() - def getNumExecutorsFailed: Int = synchronized { - val endTime = clock.getTimeMillis() + def getNumReleasedContainers: Int = releasedContainers.size() - while (executorFailuresValidityInterval > 0 - && failedExecutorsTimeStamps.nonEmpty - && failedExecutorsTimeStamps.head < endTime - executorFailuresValidityInterval) { - failedExecutorsTimeStamps.dequeue() - } + def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors - failedExecutorsTimeStamps.size - } + def isAllNodeBlacklisted: Boolean = allocatorBlacklistTracker.isAllNodeBlacklisted /** * A sequence of pending container requests that have not yet been fulfilled. */ def getPendingAllocate: Seq[ContainerRequest] = getPendingAtLocation(ANY_HOST) + def numContainersPendingAllocate: Int = synchronized { + getPendingAllocate.size + } + /** * A sequence of pending container requests at the given location that have not yet been * fulfilled. @@ -204,9 +198,8 @@ private[yarn] class YarnAllocator( * @param localityAwareTasks number of locality aware tasks to be used as container placement hint * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as * container placement hint. - * @param nodeBlacklist a set of blacklisted nodes, which is passed in to avoid allocating new - * containers on them. It will be used to update the application master's - * blacklist. + * @param nodeBlacklist blacklisted nodes, which is passed in to avoid allocating new containers + * on them. It will be used to update the application master's blacklist. * @return Whether the new requested total is different than the old value. */ def requestTotalExecutorsWithPreferredLocalities( @@ -220,19 +213,7 @@ private[yarn] class YarnAllocator( if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal - - // Update blacklist infomation to YARN ResouceManager for this application, - // in order to avoid allocating new Containers on the problematic nodes. 
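getNumExecutorsFailed above now delegates to the new FailureTracker, which keeps a queue of failure timestamps and drops entries that fall outside spark.yarn.executor.failuresValidityInterval before counting. A minimal sketch of that sliding window; the timestamps mirror the FailureTrackerSuite added below, and the helper names are made up:

    import scala.collection.mutable

    object FailureWindowSketch {
      def main(args: Array[String]): Unit = {
        val validityIntervalMs = 100L
        val failureTimestamps = mutable.Queue[Long]()

        def registerFailure(now: Long): Unit = failureTimestamps.enqueue(now)

        def numFailures(now: Long): Int = {
          // Expire anything older than the validity window, then count what is left.
          while (failureTimestamps.nonEmpty && failureTimestamps.head < now - validityIntervalMs) {
            failureTimestamps.dequeue()
          }
          failureTimestamps.size
        }

        Seq(0L, 10L, 20L, 30L).foreach(registerFailure)
        println(numFailures(101L)) // 3 -> the failure at t=0 has expired
        println(numFailures(231L)) // 0 -> all failures have expired
      }
    }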
- val blacklistAdditions = nodeBlacklist -- currentNodeBlacklist - val blacklistRemovals = currentNodeBlacklist -- nodeBlacklist - if (blacklistAdditions.nonEmpty) { - logInfo(s"adding nodes to YARN application master's blacklist: $blacklistAdditions") - } - if (blacklistRemovals.nonEmpty) { - logInfo(s"removing nodes from YARN application master's blacklist: $blacklistRemovals") - } - amClient.updateBlacklist(blacklistAdditions.toList.asJava, blacklistRemovals.toList.asJava) - currentNodeBlacklist = nodeBlacklist + allocatorBlacklistTracker.setSchedulerBlacklistedNodes(nodeBlacklist) true } else { false @@ -268,6 +249,7 @@ private[yarn] class YarnAllocator( val allocateResponse = amClient.allocate(progressIndicator) val allocatedContainers = allocateResponse.getAllocatedContainers() + allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes) if (allocatedContainers.size > 0) { logDebug(("Allocated containers: %d. Current executor count: %d. " + @@ -602,8 +584,9 @@ private[yarn] class YarnAllocator( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) case _ => - // Enqueue the timestamp of failed executor - failedExecutorsTimeStamps.enqueue(clock.getTimeMillis()) + // all the failures which not covered above, like: + // disk failure, kill by app master or resource manager, ... + allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + ". Diagnostics: " + completedContainer.getDiagnostics) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala new file mode 100644 index 0000000000000..ceac7cda5f8be --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.HashMap + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.scheduler.BlacklistTracker +import org.apache.spark.util.{Clock, SystemClock} + +/** + * YarnAllocatorBlacklistTracker is responsible for tracking the blacklisted nodes + * and synchronizing the node list to YARN. 
+ * + * Blacklisted nodes are coming from two different sources: + * + * <ul> + *   <li>from the scheduler as task level blacklisted nodes</li> + *   <li>from this class (tracked here) as YARN resource allocation problems</li> + * </ul>
      + * + * The reason to realize this logic here (and not in the driver) is to avoid possible delays + * between synchronizing the blacklisted nodes with YARN and resource allocations. + */ +private[spark] class YarnAllocatorBlacklistTracker( + sparkConf: SparkConf, + amClient: AMRMClient[ContainerRequest], + failureTracker: FailureTracker) + extends Logging { + + private val blacklistTimeoutMillis = BlacklistTracker.getBlacklistTimeout(sparkConf) + + private val launchBlacklistEnabled = sparkConf.get(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED) + + private val maxFailuresPerHost = sparkConf.get(MAX_FAILED_EXEC_PER_NODE) + + private val allocatorBlacklist = new HashMap[String, Long]() + + private var currentBlacklistedYarnNodes = Set.empty[String] + + private var schedulerBlacklist = Set.empty[String] + + private var numClusterNodes = Int.MaxValue + + def setNumClusterNodes(numClusterNodes: Int): Unit = { + this.numClusterNodes = numClusterNodes + } + + def handleResourceAllocationFailure(hostOpt: Option[String]): Unit = { + hostOpt match { + case Some(hostname) if launchBlacklistEnabled => + // failures on an already blacklisted nodes are not even tracked. + // otherwise, such failures could shutdown the application + // as resource requests are asynchronous + // and a late failure response could exceed MAX_EXECUTOR_FAILURES + if (!schedulerBlacklist.contains(hostname) && + !allocatorBlacklist.contains(hostname)) { + failureTracker.registerFailureOnHost(hostname) + updateAllocationBlacklistedNodes(hostname) + } + case _ => + failureTracker.registerExecutorFailure() + } + } + + private def updateAllocationBlacklistedNodes(hostname: String): Unit = { + val failuresOnHost = failureTracker.numFailuresOnHost(hostname) + if (failuresOnHost > maxFailuresPerHost) { + logInfo(s"blacklisting $hostname as YARN allocation failed $failuresOnHost times") + allocatorBlacklist.put( + hostname, + failureTracker.clock.getTimeMillis() + blacklistTimeoutMillis) + refreshBlacklistedNodes() + } + } + + def setSchedulerBlacklistedNodes(schedulerBlacklistedNodesWithExpiry: Set[String]): Unit = { + this.schedulerBlacklist = schedulerBlacklistedNodesWithExpiry + refreshBlacklistedNodes() + } + + def isAllNodeBlacklisted: Boolean = currentBlacklistedYarnNodes.size >= numClusterNodes + + private def refreshBlacklistedNodes(): Unit = { + removeExpiredYarnBlacklistedNodes() + val allBlacklistedNodes = schedulerBlacklist ++ allocatorBlacklist.keySet + synchronizeBlacklistedNodeWithYarn(allBlacklistedNodes) + } + + private def synchronizeBlacklistedNodeWithYarn(nodesToBlacklist: Set[String]): Unit = { + // Update blacklist information to YARN ResourceManager for this application, + // in order to avoid allocating new Containers on the problematic nodes. 
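The lines that follow compute only the delta between what YARN already knows and the currently desired blacklist, so each cycle sends additions and removals rather than the full set. A tiny standalone sketch of that set arithmetic; host names are invented, and the output matches the "combining scheduler and allocation blacklist" test added below:

    object BlacklistDeltaSketch {
      def main(args: Array[String]): Unit = {
        val currentlyBlacklisted = Set("host1", "host2", "host3") // already reported to YARN
        val desiredBlacklist = Set("host1", "host3", "host4")     // scheduler + allocator view now

        val additions = (desiredBlacklist -- currentlyBlacklisted).toList.sorted
        val removals = (currentlyBlacklisted -- desiredBlacklist).toList.sorted
        println(s"updateBlacklist(additions = $additions, removals = $removals)")
        // updateBlacklist(additions = List(host4), removals = List(host2))
      }
    }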
+ val additions = (nodesToBlacklist -- currentBlacklistedYarnNodes).toList.sorted + val removals = (currentBlacklistedYarnNodes -- nodesToBlacklist).toList.sorted + if (additions.nonEmpty) { + logInfo(s"adding nodes to YARN application master's blacklist: $additions") + } + if (removals.nonEmpty) { + logInfo(s"removing nodes from YARN application master's blacklist: $removals") + } + amClient.updateBlacklist(additions.asJava, removals.asJava) + currentBlacklistedYarnNodes = nodesToBlacklist + } + + private def removeExpiredYarnBlacklistedNodes(): Unit = { + val now = failureTracker.clock.getTimeMillis() + allocatorBlacklist.retain { (_, expiryTime) => expiryTime > now } + } +} + +/** + * FailureTracker is responsible for tracking executor failures both for each host separately + * and for all hosts altogether. + */ +private[spark] class FailureTracker( + sparkConf: SparkConf, + val clock: Clock = new SystemClock) extends Logging { + + private val executorFailuresValidityInterval = + sparkConf.get(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + + // Queue to store the timestamp of failed executors for each host + private val failedExecutorsTimeStampsPerHost = mutable.Map[String, mutable.Queue[Long]]() + + private val failedExecutorsTimeStamps = new mutable.Queue[Long]() + + private def updateAndCountFailures(failedExecutorsWithTimeStamps: mutable.Queue[Long]): Int = { + val endTime = clock.getTimeMillis() + while (executorFailuresValidityInterval > 0 && + failedExecutorsWithTimeStamps.nonEmpty && + failedExecutorsWithTimeStamps.head < endTime - executorFailuresValidityInterval) { + failedExecutorsWithTimeStamps.dequeue() + } + failedExecutorsWithTimeStamps.size + } + + def numFailedExecutors: Int = synchronized { + updateAndCountFailures(failedExecutorsTimeStamps) + } + + def registerFailureOnHost(hostname: String): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + val failedExecutorsOnHost = + failedExecutorsTimeStampsPerHost.getOrElse(hostname, { + val failureOnHost = mutable.Queue[Long]() + failedExecutorsTimeStampsPerHost.put(hostname, failureOnHost) + failureOnHost + }) + failedExecutorsOnHost.enqueue(timeMillis) + } + + def registerExecutorFailure(): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + } + + def numFailuresOnHost(hostname: String): Int = { + failedExecutorsTimeStampsPerHost.get(hostname).map { failedExecutorsOnHost => + updateAndCountFailures(failedExecutorsOnHost) + }.getOrElse(0) + } + +} + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 17234b120ae13..05a7b1e1310c4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.yarn import scala.collection.JavaConverters._ +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest @@ -42,23 +43,20 @@ private[spark] class YarnRMClient extends Logging { /** * Registers the application master with the RM. * + * @param driverHost Host name where driver is running. 
+ * @param driverPort Port where driver is listening. * @param conf The Yarn configuration. * @param sparkConf The Spark configuration. * @param uiAddress Address of the SparkUI. * @param uiHistoryAddress Address of the application on the History Server. - * @param securityMgr The security manager. - * @param localResources Map with information about files distributed via YARN's cache. */ def register( - driverUrl: String, - driverRef: RpcEndpointRef, + driverHost: String, + driverPort: Int, conf: YarnConfiguration, sparkConf: SparkConf, uiAddress: Option[String], - uiHistoryAddress: String, - securityMgr: SecurityManager, - localResources: Map[String, LocalResource] - ): YarnAllocator = { + uiHistoryAddress: String): Unit = { amClient = AMRMClient.createAMRMClient() amClient.init(conf) amClient.start() @@ -70,10 +68,19 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, - trackingUrl) + amClient.registerApplicationMaster(driverHost, driverPort, trackingUrl) registered = true } + } + + def createAllocator( + conf: YarnConfiguration, + sparkConf: SparkConf, + driverUrl: String, + driverRef: RpcEndpointRef, + securityMgr: SecurityManager, + localResources: Map[String, LocalResource]): YarnAllocator = { + require(registered, "Must register AM before creating allocator.") new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, localResources, new SparkRackResolver()) } @@ -88,6 +95,9 @@ private[spark] class YarnRMClient extends Logging { if (registered) { amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } + if (amClient != null) { + amClient.stop() + } } /** Returns the attempt ID. */ @@ -103,7 +113,16 @@ private[spark] class YarnRMClient extends Logging { val proxies = WebAppUtils.getProxyHostsAndPortsForAmFilter(conf) val hosts = proxies.asScala.map(_.split(":").head) val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } - Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) + val params = + Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) + + // Handles RM HA urls + val rmIds = conf.getStringCollection(YarnConfiguration.RM_HA_IDS).asScala + if (rmIds != null && rmIds.nonEmpty) { + params + ("RM_HA_URLS" -> rmIds.map(getUrlByRmId(conf, _)).mkString(",")) + } else { + params + } } /** Returns the maximum number of attempts to register the AM. */ @@ -117,4 +136,21 @@ private[spark] class YarnRMClient extends Logging { } } + private def getUrlByRmId(conf: Configuration, rmId: String): String = { + val addressPropertyPrefix = if (YarnConfiguration.useHttps(conf)) { + YarnConfiguration.RM_WEBAPP_HTTPS_ADDRESS + } else { + YarnConfiguration.RM_WEBAPP_ADDRESS + } + + val addressWithRmId = if (rmId == null || rmId.isEmpty) { + addressPropertyPrefix + } else if (rmId.startsWith(".")) { + throw new IllegalStateException(s"rmId $rmId should not already have '.' 
prepended.") + } else { + s"$addressPropertyPrefix.$rmId" + } + + conf.get(addressWithRmId) + } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 8eda6cb1277c5..3a3272216294f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -27,11 +27,8 @@ import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} import org.apache.hadoop.yarn.util.ConverterUtils -import org.apache.spark.{SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager -import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils @@ -193,14 +190,35 @@ object YarnSparkHadoopUtil { sparkConf: SparkConf, hadoopConf: Configuration): Set[FileSystem] = { val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS) - .map(new Path(_).getFileSystem(hadoopConf)) - .toSet + val requestAllDelegationTokens = filesystemsToAccess.isEmpty val stagingFS = sparkConf.get(STAGING_DIR) .map(new Path(_).getFileSystem(hadoopConf)) .getOrElse(FileSystem.get(hadoopConf)) - filesystemsToAccess + stagingFS + // Add the list of available namenodes for all namespaces in HDFS federation. + // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its + // namespaces. 
+ val hadoopFilesystems = if (!requestAllDelegationTokens || stagingFS.getScheme == "viewfs") { + filesystemsToAccess.map(new Path(_).getFileSystem(hadoopConf)).toSet + } else { + val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices") + // Retrieving the filesystem for the nameservices where HA is not enabled + val filesystemsWithoutHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.namenode.rpc-address.$ns")).map { nameNode => + new Path(s"hdfs://$nameNode").getFileSystem(hadoopConf) + } + } + // Retrieving the filesystem for the nameservices where HA is enabled + val filesystemsWithHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.ha.namenodes.$ns")).map { _ => + new Path(s"hdfs://$ns").getFileSystem(hadoopConf) + } + } + (filesystemsWithoutHA ++ filesystemsWithHA).toSet + } + + hadoopFilesystems + stagingFS } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 1a99b3bd57672..ab8273bd6321d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -152,6 +152,11 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("100s") + private[spark] val YARN_METRICS_NAMESPACE = ConfigBuilder("spark.yarn.metrics.namespace") + .doc("The root namespace for AM metrics reporting.") + .stringConf + .createOptional + private[spark] val AM_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.am.nodeLabelExpression") .doc("Node label expression for the AM.") .stringConf @@ -187,6 +192,12 @@ package object config { .toSequence .createWithDefault(Nil) + private[spark] val AM_FINAL_MSG_LIMIT = ConfigBuilder("spark.yarn.am.finalMessageLimit") + .doc("The limit size of final diagnostic message for our ApplicationMaster to unregister from" + + " the ResourceManager.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + /* Client-mode AM configuration. */ private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores") @@ -328,4 +339,10 @@ package object config { CACHED_FILES_TYPES, CACHED_CONF_ARCHIVE) + /* YARN allocator-level blacklisting related config entries. */ + private[spark] val YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED = + ConfigBuilder("spark.yarn.blacklist.executor.launch.blacklisting.enabled") + .booleanConf + .createWithDefault(false) + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index d4eeb6bbcf886..26a2e5d730218 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -44,6 +44,10 @@ private[yarn] class YARNHadoopDelegationTokenManager( // public for testing val credentialProviders = getCredentialProviders + if (credentialProviders.nonEmpty) { + logDebug("Using the following YARN-specific credential providers: " + + s"${credentialProviders.keys.mkString(", ")}.") + } /** * Writes delegation tokens to creds. 
Delegation tokens are fetched from all registered diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 06e54a2eaf95a..9397a1e3de9ac 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnAppReport} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -75,13 +75,23 @@ private[spark] class YarnClientSchedulerBackend( val monitorInterval = conf.get(CLIENT_LAUNCH_MONITOR_INTERVAL) assert(client != null && appId.isDefined, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true, - interval = monitorInterval) // blocking + val YarnAppReport(state, _, diags) = client.monitorApplication(appId.get, + returnOnRunning = true, interval = monitorInterval) if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - throw new SparkException("Yarn application has already ended! " + - "It might have been killed or unable to launch application master.") + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + val genericMessage = "The YARN application has already ended! " + + "It might have been killed or the Application Master may have failed to start. " + + "Check the YARN application logs for more details." + val exceptionMsg = diags match { + case Some(msg) => + logError(genericMessage) + msg + + case None => + genericMessage + } + throw new SparkException(exceptionMsg) } if (state == YarnApplicationState.RUNNING) { logInfo(s"Application ${appId.get} has started running.") @@ -100,8 +110,13 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") + val YarnAppReport(_, state, diags) = + client.monitorApplication(appId.get, logApplicationReport = false) + logError(s"YARN application has exited unexpectedly with state $state! 
" + + "Check the YARN application logs for more details.") + diags.foreach { err => + logError(s"Diagnostics message: $err") + } allowInterrupt = false sc.stop() } catch { @@ -124,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend( private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread - t.setName("Yarn application state monitor") + t.setName("YARN application state monitor") t.setDaemon(true) t } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index ac67f2196e0a0..3a7913122dd83 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher._ import org.apache.spark.util.Utils @@ -132,7 +133,8 @@ abstract class BaseYarnClusterSuite extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), - extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { + extraEnv: Map[String, String] = Map(), + outFile: Option[File] = None): SparkAppHandle.State = { val deployMode = if (clientMode) "client" else "cluster" val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv @@ -160,6 +162,11 @@ abstract class BaseYarnClusterSuite } extraJars.foreach(launcher.addJar) + if (outFile.isDefined) { + launcher.redirectOutput(outFile.get) + launcher.redirectError() + } + val handle = launcher.startApplication() try { eventually(timeout(2 minutes), interval(1 second)) { @@ -178,17 +185,22 @@ abstract class BaseYarnClusterSuite * the tests enforce that something is written to a file after everything is ok to indicate * that the job succeeded. */ - protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = { - checkResult(finalState, result, "success") - } - protected def checkResult( finalState: SparkAppHandle.State, result: File, - expected: String): Unit = { - finalState should be (SparkAppHandle.State.FINISHED) + expected: String = "success", + outFile: Option[File] = None): Unit = { + // the context message is passed to assert as Any instead of a function. 
to lazily load the + // output from the file, this passes an anonymous object that loads it in toString when building + // an error message + val output = new Object() { + override def toString: String = outFile + .map(Files.toString(_, StandardCharsets.UTF_8)) + .getOrElse("(stdout/stderr was not captured)") + } + assert(finalState === SparkAppHandle.State.FINISHED, output) val resultString = Files.toString(result, StandardCharsets.UTF_8) - resultString should be (expected) + assert(resultString === expected, output) } protected def mainClassName(klass: Class[_]): String = { @@ -216,6 +228,14 @@ abstract class BaseYarnClusterSuite props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + // SPARK-24446: make sure special characters in the library path do not break containers. + if (!Utils.isWindows) { + val libPath = """/tmp/does not exist:$PWD/tmp:/tmp/quote":/tmp/ampersand&""" + props.setProperty(AM_LIBRARY_PATH.key, libPath) + props.setProperty(DRIVER_LIBRARY_PATH.key, libPath) + props.setProperty(EXECUTOR_LIBRARY_PATH.key, libPath) + } + yarnCluster.getConfig().asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala new file mode 100644 index 0000000000000..4f77b9c99dd25 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.yarn + +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.ManualClock + +class FailureTrackerSuite extends SparkFunSuite with Matchers { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("failures expire if validity interval is set") { + val sparkConf = new SparkConf() + sparkConf.set(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS, 100L) + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(101) + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(231) + failureTracker.numFailuresOnHost("host1") should be (0) + failureTracker.numFailuresOnHost("host2") should be (0) + failureTracker.numFailedExecutors should be (0) + } + + + test("failures never expire if validity interval is not set (-1)") { + val sparkConf = new SparkConf() + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(1000) + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + } + +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala new file mode 100644 index 0000000000000..aeac68e6ed330 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.util.Arrays +import java.util.Collections + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.config.YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED +import org.apache.spark.internal.config.{BLACKLIST_TIMEOUT_CONF, MAX_FAILED_EXEC_PER_NODE} +import org.apache.spark.util.ManualClock + +class YarnAllocatorBlacklistTrackerSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach { + + val BLACKLIST_TIMEOUT = 100L + val MAX_FAILED_EXEC_PER_NODE_VALUE = 2 + + var amClientMock: AMRMClient[ContainerRequest] = _ + var yarnBlacklistTracker: YarnAllocatorBlacklistTracker = _ + var failureTracker: FailureTracker = _ + var clock: ManualClock = _ + + override def beforeEach(): Unit = { + val sparkConf = new SparkConf() + sparkConf.set(BLACKLIST_TIMEOUT_CONF, BLACKLIST_TIMEOUT) + sparkConf.set(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED, true) + sparkConf.set(MAX_FAILED_EXEC_PER_NODE, MAX_FAILED_EXEC_PER_NODE_VALUE) + clock = new ManualClock() + + amClientMock = mock(classOf[AMRMClient[ContainerRequest]]) + failureTracker = new FailureTracker(sparkConf, clock) + yarnBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClientMock, failureTracker) + yarnBlacklistTracker.setNumClusterNodes(4) + super.beforeEach() + } + + test("expiring its own blacklisted nodes") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // host should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + } + } + + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // the third failure on the host triggers the blacklisting + verify(amClientMock).updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + + clock.advance(BLACKLIST_TIMEOUT) + + // trigger synchronisation of blacklisted nodes with YARN + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set()) + verify(amClientMock).updateBlacklist(Collections.emptyList(), Arrays.asList("host")) + } + + test("not handling the expiry of scheduler blacklisted nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2"), Collections.emptyList()) + + // advance timer more then host1, host2 expiry time + clock.advance(200L) + + // expired blacklisted nodes (simulating a resource request) + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + // no change is communicated to YARN regarding the blacklisting + verify(amClientMock).updateBlacklist(Collections.emptyList(), Collections.emptyList()) + } + + test("combining scheduler and allocation blacklist") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + 
yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + // host1 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + } + } + + // as this is the third failure on host1 the node will be blacklisted + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host2", "host3"), Collections.emptyList()) + + clock.advance(10L) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host3", "host4")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host4"), Arrays.asList("host2")) + } + + test("blacklist all available nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2", "host3"), Collections.emptyList()) + + clock.advance(60L) + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + // host4 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + } + } + + // the third failure on the host triggers the blacklisting + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + + verify(amClientMock).updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + assert(yarnBlacklistTracker.isAllNodeBlacklisted === true) + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 525abb6f2b350..3f783baed110d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -59,6 +59,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter var rmClient: AMRMClient[ContainerRequest] = _ + var clock: ManualClock = _ + var containerNum = 0 override def beforeEach() { @@ -66,6 +68,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient = AMRMClient.createAMRMClient() rmClient.init(conf) rmClient.start() + clock = new ManualClock() } override def afterEach() { @@ -101,7 +104,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter appAttemptId, new SecurityManager(sparkConf), Map(), - new MockResolver()) + new MockResolver(), + clock) } def createContainer(host: String): Container = { @@ -332,10 +336,14 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map(), Set("hostA")) verify(mockAmClient).updateBlacklist(Seq("hostA").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), Set("hostA", "hostB")) + val blacklistedNodes = Set( + "hostA", + "hostB" + ) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), blacklistedNodes) verify(mockAmClient).updateBlacklist(Seq("hostB").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set()) + 
handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set.empty) verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } @@ -353,8 +361,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) - val clock = new ManualClock(0L) - handler.setClock(clock) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a129be7c06b53..58d11e96942e1 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -108,7 +108,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", "spark.executor.instances" -> "2", - // Sending some senstive information, which we'll make sure gets redacted + // Sending some sensitive information, which we'll make sure gets redacted "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) @@ -265,35 +265,32 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.6-src.zip", + s"$sparkHome/python/lib/py4j-0.10.7-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv - val moduleDir = - if (clientMode) { - // In client-mode, .py files added with --py-files are not visible in the driver. - // This is something that the launcher library would have to handle. 
- tempDir - } else { - val subdir = new File(tempDir, "pyModules") - subdir.mkdir() - subdir - } + val moduleDir = { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } val pyModule = new File(moduleDir, "mod1.py") Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") val result = File.createTempFile("result", null, tempDir) + val outFile = Some(File.createTempFile("stdout", null, tempDir)) val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(), sparkArgs = Seq("--py-files" -> pyFiles), appArgs = Seq(result.getAbsolutePath()), extraEnv = extraEnvVars, - extraConf = extraConf) - checkResult(finalState, result) + extraConf = extraConf, + outFile = outFile) + checkResult(finalState, result, outFile = outFile) } private def testUseClassPathFirst(clientMode: Boolean): Unit = { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index f21353aa007c8..61c0c43f7c04f 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -21,7 +21,8 @@ import java.io.{File, IOException} import java.nio.charset.StandardCharsets import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.io.Text +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -141,4 +142,66 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } + test("SPARK-24149: retrieve all namenodes from HDFS") { + val sparkConf = new SparkConf() + val basicFederationConf = new Configuration() + basicFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + basicFederationConf.set("dfs.nameservices", "ns1,ns2") + basicFederationConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020") + basicFederationConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021") + val basicFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(basicFederationConf), + new Path("hdfs://localhost:8021").getFileSystem(basicFederationConf)) + val basicFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, basicFederationConf) + basicFederationResult should be (basicFederationExpected) + + // when viewfs is enabled, namespaces are handled by it, so we don't need to take care of them + val viewFsConf = new Configuration() + viewFsConf.addResource(basicFederationConf) + viewFsConf.set("fs.defaultFS", "viewfs://clusterX/") + viewFsConf.set("fs.viewfs.mounttable.clusterX.link./home", "hdfs://localhost:8020/") + val viewFsExpected = Set(new Path("viewfs://clusterX/").getFileSystem(viewFsConf)) + YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, viewFsConf) should be (viewFsExpected) + + // invalid config should not throw NullPointerException + val invalidFederationConf = new Configuration() + invalidFederationConf.addResource(basicFederationConf) + invalidFederationConf.unset("dfs.namenode.rpc-address.ns2") + val invalidFederationExpected = Set( + new 
Path("hdfs://localhost:8020").getFileSystem(invalidFederationConf)) + val invalidFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, invalidFederationConf) + invalidFederationResult should be (invalidFederationExpected) + + // no namespaces defined, ie. old case + val noFederationConf = new Configuration() + noFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + val noFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(noFederationConf)) + val noFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, noFederationConf) + noFederationResult should be (noFederationExpected) + + // federation and HA enabled + val federationAndHAConf = new Configuration() + federationAndHAConf.set("fs.defaultFS", "hdfs://clusterXHA") + federationAndHAConf.set("dfs.nameservices", "clusterXHA,clusterYHA") + federationAndHAConf.set("dfs.ha.namenodes.clusterXHA", "x-nn1,x-nn2") + federationAndHAConf.set("dfs.ha.namenodes.clusterYHA", "y-nn1,y-nn2") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn1", "localhost:8020") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn2", "localhost:8021") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn1", "localhost:8022") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn2", "localhost:8023") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterXHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterYHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + + val federationAndHAExpected = Set( + new Path("hdfs://clusterXHA").getFileSystem(federationAndHAConf), + new Path("hdfs://clusterYHA").getFileSystem(federationAndHAConf)) + val federationAndHAResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, federationAndHAConf) + federationAndHAResult should be (federationAndHAExpected) + } } diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index bac154e10ae62..bf3da18c3706e 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi diff --git a/scalastyle-config.xml b/scalastyle-config.xml index e65e3aafe5b5b..da5c3f29c32dc 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -150,6 +150,19 @@ This file is divided into 3 sections: // scalastyle:on println]]> + + spark(.sqlContext)?.sparkContext.hadoopConfiguration + + + @VisibleForTesting ' expression #lambda + | '(' IDENTIFIER (',' IDENTIFIER)+ ')' '->' expression #lambda | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' 
fieldName=identifier #dereference | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract ; constant @@ -725,7 +746,7 @@ nonReserved | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER + | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP @@ -735,6 +756,7 @@ nonReserved | VIEW | REPLACE | IF | POSITION + | EXTRACT | NO | DATA | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION @@ -745,7 +767,7 @@ nonReserved | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE + | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN @@ -760,6 +782,7 @@ FROM: 'FROM'; ADD: 'ADD'; AS: 'AS'; ALL: 'ALL'; +ANY: 'ANY'; DISTINCT: 'DISTINCT'; WHERE: 'WHERE'; GROUP: 'GROUP'; @@ -805,6 +828,7 @@ RIGHT: 'RIGHT'; FULL: 'FULL'; NATURAL: 'NATURAL'; ON: 'ON'; +PIVOT: 'PIVOT'; LATERAL: 'LATERAL'; WINDOW: 'WINDOW'; OVER: 'OVER'; @@ -872,6 +896,7 @@ TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; +EXTRACT: 'EXTRACT'; EQ : '=' | '=='; NSEQ: '<=>'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d5d934bc91cab..cf2a5ed2e27f9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -83,7 +83,7 @@ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elemen private long elementOffset; private long getElementOffset(int ordinal, int elementSize) { - return elementOffset + ordinal * elementSize; + return elementOffset + ordinal * (long)elementSize; } public Object getBaseObject() { return baseObject; } @@ -414,7 +414,7 @@ public byte[] toByteArray() { public short[] toShortArray() { short[] values = new short[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L); return values; } @@ -422,7 +422,7 @@ public short[] toShortArray() { public int[] toIntArray() { int[] values = new int[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -430,7 +430,7 @@ public int[] toIntArray() { public long[] toLongArray() { long[] values = new 
long[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L); return values; } @@ -438,7 +438,7 @@ public long[] toLongArray() { public float[] toFloatArray() { float[] values = new float[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -446,14 +446,14 @@ public float[] toFloatArray() { public double[] toDoubleArray() { double[] values = new double[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L); return values; } - private static UnsafeArrayData fromPrimitiveArray( + public static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = elementSize * length; + final long valueRegionInBytes = (long)elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + @@ -463,14 +463,27 @@ private static UnsafeArrayData fromPrimitiveArray( final long[] data = new long[(int)totalSizeInLongs]; Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); - Platform.copyMemory(arr, offset, data, - Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + if (arr != null) { + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + } UnsafeArrayData result = new UnsafeArrayData(); result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } + public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) { + return fromPrimitiveArray(null, offset, length, elementSize); + } + + public static boolean shouldUseGenericArrayData(int elementSize, int length) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + return totalSizeInLongs > Integer.MAX_VALUE / 8; + } + public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 29a1411241cf6..469b0e60cc9a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -62,6 +62,8 @@ */ public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { + public static final int WORD_SIZE = 8; + ////////////////////////////////////////////////////////////////////////////// // Static methods ////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java 
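The L suffixes and (long) casts added to UnsafeArrayData above matter because expressions such as numElements * 8 are evaluated in 32-bit arithmetic before any widening, so large arrays silently overflow. A small sketch of the failure mode (the element count is an arbitrary illustrative value):

object OverflowDemo extends App {
  val numElements = 300 * 1000 * 1000              // 300 million entries (fits in an Int)
  val elementSize = 8                              // e.g. an 8-byte long field

  val intMath = numElements * elementSize          // 32-bit multiply overflows first
  val longMath = numElements * elementSize.toLong  // widen one operand before multiplying

  println(s"Int arithmetic:  $intMath")            // negative: -1894967296
  println(s"Long arithmetic: $longMath")           // correct:   2400000000
}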
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index 905e6820ce6e2..c823de4810f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB @Override public UnsafeRow appendRow(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) { - final long recordLength = 8 + klen + vlen + 8; + final long recordLength = 8L + klen + vlen + 8; // if run out of max supported rows or page size, return null if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { return null; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 537ef244b7e81..6a52a5b0e0664 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -35,6 +35,7 @@ final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + // buffer is guaranteed to be word-aligned since UnsafeRow assumes each field is word-aligned. private byte[] buffer; private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; @@ -52,7 +53,8 @@ final class BufferHolder { "too many fields (number of fields: " + row.numFields() + ")"); } this.fixedSize = bitsetWidthInBytes + 8 * row.numFields(); - this.buffer = new byte[fixedSize + initialSize]; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(fixedSize + initialSize); + this.buffer = new byte[roundedSize]; this.row = row; this.row.pointTo(buffer, buffer.length); } @@ -61,8 +63,12 @@ final class BufferHolder { * Grows the buffer by at least neededSize and points the row to the buffer. */ void grow(int neededSize) { + if (neededSize < 0) { + throw new IllegalArgumentException( + "Cannot grow BufferHolder by size " + neededSize + " because the size is negative"); + } if (neededSize > ARRAY_MAX - totalSize()) { - throw new UnsupportedOperationException( + throw new IllegalArgumentException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } @@ -70,7 +76,8 @@ void grow(int neededSize) { if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. int newLength = length < ARRAY_MAX / 2 ?
length * 2 : ARRAY_MAX; - final byte[] tmp = new byte[newLength]; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(newLength); + final byte[] tmp = new byte[roundedSize]; Platform.copyMemory( buffer, Platform.BYTE_ARRAY_OFFSET, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index d224332d8a6c9..023ec139652c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -21,6 +21,9 @@ import java.io.Reader; import javax.xml.namespace.QName; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpression; @@ -37,9 +40,15 @@ * This is based on Hive's UDFXPathUtil implementation. */ public class UDFXPathUtil { + public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/"; + public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities"; + public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities"; + private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); + private DocumentBuilder builder = null; private XPath xpath = XPathFactory.newInstance().newXPath(); private ReusableStringReader reader = new ReusableStringReader(); private InputSource inputSource = new InputSource(reader); + private XPathExpression expression = null; private String oldPath = null; @@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE return null; } + if (builder == null){ + try { + initializeDocumentBuilderFactory(); + builder = dbf.newDocumentBuilder(); + } catch (ParserConfigurationException e) { + throw new RuntimeException( + "Error instantiating DocumentBuilder, cannot build xml parser", e); + } + } + reader.set(xml); try { - return expression.evaluate(inputSource, qname); + return expression.evaluate(builder.parse(inputSource), qname); } catch (XPathExpressionException e) { throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); + } catch (Exception e) { + throw new RuntimeException("Error loading expression '" + oldPath + "'", e); } } + private void initializeDocumentBuilderFactory() throws ParserConfigurationException { + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false); + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false); + } + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index bb77b5bf6de2a..40c2cc806e87a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -22,12 +22,10 @@ public final class RecordBinaryComparator extends RecordComparator { - // TODO(jiangxb) Add test suite for this. 
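The UDFXPathUtil change above routes the XML through a DocumentBuilder whose factory has external general and parameter entities disabled, which is the standard JAXP defence against XXE when evaluating XPath over untrusted input. A minimal sketch of that hardening (SecureXmlParse and the inline XML are illustrative; the two SAX feature URIs are the same ones used in the patch):

import java.io.StringReader
import javax.xml.parsers.DocumentBuilderFactory
import org.xml.sax.InputSource

object SecureXmlParse extends App {
  val dbf = DocumentBuilderFactory.newInstance()
  // Same SAX features the patch disables: no external general or parameter entities.
  dbf.setFeature("http://xml.org/sax/features/external-general-entities", false)
  dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false)
  val builder = dbf.newDocumentBuilder()

  val doc = builder.parse(new InputSource(new StringReader("<a><b>1</b></a>")))
  println(doc.getDocumentElement.getTagName)       // prints "a"
}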
@Override public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { int i = 0; - int res = 0; // If the arrays have different length, the longer one is larger. if (leftLen != rightLen) { @@ -40,27 +38,33 @@ public int compare( // check if stars align and we can get both offsets to be aligned if ((leftOff % 8) == (rightOff % 8)) { while ((leftOff + i) % 8 != 0 && i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - res = (int) ((Platform.getLong(leftObj, leftOff + i) - - Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); - if (res != 0) return res; + final long v1 = Platform.getLong(leftObj, leftOff + i); + final long v2 = Platform.getLong(rightObj, rightOff + i); + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 8; } } // this will finish off the unaligned comparisons, or do the entire aligned comparison // whichever is needed. while (i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index ccdb6bc5d4b7c..7b02317b8538f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -68,10 +68,10 @@ import org.apache.spark.sql.types._ */ @Experimental @InterfaceStability.Evolving -@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + - "(Int, String, etc) and Product types (case classes) are supported by importing " + - "spark.implicits._ Support for serializing other types will be added in future " + - "releases.") +@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + + "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " + + "classes) are supported by importing spark.implicits._ Support for serializing other types " + + "will be added in future releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 0b95a8821b05a..b47ec0b72c638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -132,7 +132,7 @@ object Encoders { * - primitive types: boolean, int, double, etc. * - boxed types: Boolean, Integer, Double, etc. 
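The RecordBinaryComparator rewrite above stops deriving the comparison result from a subtraction, because the difference of two values can overflow and flip its sign. A small sketch of the pitfall and the explicit comparison that replaces it:

object CompareDemo extends App {
  val left = Long.MaxValue
  val right = -1L

  val bySubtraction = left - right                 // overflows to Long.MinValue
  println(bySubtraction < 0)                       // true: subtraction wrongly claims left < right
  println(java.lang.Long.compare(left, right))     // 1: explicit compare gets it right
}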
* - String - * - java.math.BigDecimal + * - java.math.BigDecimal, java.math.BigInteger * - time related: java.sql.Date, java.sql.Timestamp * - collection types: only array and java.util.List currently, map support is in progress * - nested java bean. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 474ec592201d9..6f5fbdd79e668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -170,6 +170,9 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to an array of ${elementType.catalogString}") } } @@ -206,6 +209,10 @@ object CatalystTypeConverters { scalaValue match { case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + "cannot be converted to a map type with " + + s"key type (${keyType.catalogString}) and value type (${valueType.catalogString})") } } @@ -252,6 +259,9 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${structType.catalogString}") } override def toScala(row: InternalRow): Row = { @@ -276,6 +286,10 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case chr: Char => UTF8String.fromString(chr.toString) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to the string type") } override def toScala(catalystValue: UTF8String): String = if (catalystValue == null) null else catalystValue.toString @@ -309,6 +323,9 @@ object CatalystTypeConverters { case d: JavaBigDecimal => Decimal(d) case d: JavaBigInteger => Decimal(d) case d: Decimal => d + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${dataType.catalogString}") } decimal.toPrecision(dataType.precision, dataType.scale) } @@ -414,6 +431,12 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) + case (keys: Array[_], values: Array[_]) => + // case for mapdata with duplicate keys + new ArrayBasedMapData( + new GenericArrayData(keys.map(convertToCatalyst)), + new GenericArrayData(values.map(convertToCatalyst)) + ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 818cc2fb1e8a8..0238d57de2446 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -709,6 +709,8 @@ object ScalaReflection extends ScalaReflection { def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => s.toAttributes + case others => + throw new UnsupportedOperationException(s"Attributes for type $others is not supported") } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ @@ -798,7 +800,12 @@ object ScalaReflection extends ScalaReflection { * Whether the fields of the given type is defined entirely by its constructor parameters. */ def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { - tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + tpe.dealias match { + // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. + case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head) + case _ => tpe.dealias <:< localTypeOf[Product] || + tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + } } private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", @@ -846,6 +853,19 @@ object ScalaReflection extends ScalaReflection { } } + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { if (arguments != Nil) { arguments.map(e => dataTypeJavaClass(e.dataType)) @@ -912,15 +932,6 @@ trait ScalaReflection { tpe.dealias.erasure.typeSymbol.asClass.fullName } - /** - * Returns classes of input parameters of scala function object. - */ - def getParameterTypes(func: AnyRef): Seq[Class[_]] = { - val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) - assert(methods.length == 1) - methods.head.getParameterTypes - } - /** * Returns the parameter names and types for the primary constructor of this type. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e821e96522f7c..580133dd971b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -27,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -99,11 +102,11 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } - def executeAndCheck(plan: LogicalPlan): LogicalPlan = { + def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer { val analyzed = execute(plan) try { checkAnalysis(analyzed) - EliminateBarriers(analyzed) + analyzed } catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) @@ -142,6 +145,7 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf), + ResolveHints.ResolveCoalesceHints, ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, LookupFunctions), @@ -172,13 +176,16 @@ class Analyzer( ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: + ResolveOutputRelation :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveHigherOrderFunctions(catalog) :: + ResolveLambdaVariables(conf) :: ResolveTimeZone(conf) :: - ResolvedUuidExpressions :: + ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -200,7 +207,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -210,8 +217,8 @@ class Analyzer( } def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { - plan transformDown { - case u : UnresolvedRelation => + plan resolveOperatorsDown { + case u: UnresolvedRelation => cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) .map(_._2).getOrElse(u) case other => @@ -228,19 +235,16 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. 
*/ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { // Lookup WindowSpecDefinitions. This rule works with unresolved children. - case WithWindowDefinition(windowDefinitions, child) => - child.transform { - case p => p.transformExpressions { - case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => - val errorMessage = - s"Window specification $windowName is not defined in the WINDOW clause." - val windowSpecDefinition = - windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) - WindowExpression(c, windowSpecDefinition) - } - } + case WithWindowDefinition(windowDefinitions, child) => child.resolveExpressions { + case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => + val errorMessage = + s"Window specification $windowName is not defined in the WINDOW clause." + val windowSpecDefinition = + windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) + WindowExpression(c, windowSpecDefinition) + } } } @@ -268,16 +272,16 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.copy(aggregations = assignAliases(g.aggregations)) - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) - if child.resolved && hasUnresolvedAlias(groupByExprs) => - Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) @@ -439,17 +443,35 @@ class Analyzer( child: LogicalPlan): LogicalPlan = { val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() + // In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and + // can be null. In such case, we derive the groupByExprs from the user supplied values for + // grouping sets. + val finalGroupByExpressions = if (groupByExprs == Nil) { + selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) => + // Only unique expressions are included in the group by expressions and is determined + // based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results + // in grouping expression (a * b) + if (result.find(_.semanticEquals(currentExpr)).isDefined) { + result + } else { + result :+ currentExpr + } + } + } else { + groupByExprs + } + // Expand works by setting grouping expressions to null as determined by the // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate // instead of the original value we need to create new aliases for all group by expressions // that will only be used for the intended purpose. 
- val groupByAliases = constructGroupByAlias(groupByExprs) + val groupByAliases = constructGroupByAlias(finalGroupByExpressions) val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid) val groupingAttrs = expand.output.drop(child.output.length) val aggregations = constructAggregateExprs( - groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid) + finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid) Aggregate(groupingAttrs, aggregations, expand) } @@ -470,7 +492,7 @@ class Analyzer( } // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. // Ensure group by expressions and aggregate expressions have been resolved. @@ -503,14 +525,46 @@ class Analyzer( } object ResolvePivot extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) - | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) + || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) + || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p + case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + if (!RowOrdering.isOrderable(pivotColumn.dataType)) { + throw new AnalysisException( + s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.") + } + // Check all aggregate expressions. + aggregates.foreach(checkValidAggregateExpression) + // Check all pivot values are literal and match pivot column data type. + val evalPivotValues = pivotValues.map { value => + val foldable = value match { + case Alias(v, _) => v.foldable + case _ => value.foldable + } + if (!foldable) { + throw new AnalysisException( + s"Literal expressions required for pivot values, found '$value'") + } + if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { + throw new AnalysisException(s"Invalid pivot value '$value': " + + s"value data type ${value.dataType.simpleString} does not match " + + s"pivot column data type ${pivotColumn.dataType.catalogString}") + } + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + } + // Group-by expressions coming from SQL are implicit and need to be deduced. 
+ val groupByExprs = groupByExprsOpt.getOrElse( + (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 - def outputName(value: Literal, aggregate: Expression): String = { - val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) - val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") + def outputName(value: Expression, aggregate: Expression): String = { + val stringValue = value match { + case n: NamedExpression => n.name + case _ => + val utf8Value = + Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + Option(utf8Value).map(_.toString).getOrElse("null") + } if (singleAgg) { stringValue } else { @@ -531,9 +585,8 @@ class Analyzer( } val bigGroup = groupByExprs :+ namedPivotCol val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) - val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow)) val pivotAggs = namedAggExps.map { a => - Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues) + Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } @@ -548,8 +601,12 @@ class Analyzer( Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => - def ifExpr(expr: Expression) = { - If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) + def ifExpr(e: Expression) = { + If( + EqualNullSafe( + pivotColumn, + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))), + e, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { @@ -568,16 +625,25 @@ class Analyzer( // TODO: Don't construct the physical container until after analysis. case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } - if (filteredAggregate.fastEquals(aggregate)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$aggregate'") - } Alias(filteredAggregate, outputName(value, aggregate))() } } Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } + + // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF. + // TODO: Support Pandas UDF. + private def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis. 
+ case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) => + failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.") + case e: Attribute => + failAnalysis( + s"Aggregate expression required for pivot, but '${e.sql}' " + + s"did not appear in any aggregate function.") + case e => e.children.foreach(checkValidAggregateExpression) + } } /** @@ -637,7 +703,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -661,13 +727,13 @@ class Analyzer( try { catalog.lookupRelation(tableIdentWithDb) } catch { - case _: NoSuchTableException => - u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") + case e: NoSuchTableException => + u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e) // If the database is defined and that database is not found, throw an AnalysisException. // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exist.") + s"database ${e.db} doesn't exist.", e) } } @@ -698,12 +764,6 @@ class Analyzer( s"between $left and $right") right.collect { - // For `AnalysisBarrier`, recursively de-duplicate its child. - case oldVersion: AnalysisBarrier - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = dedupRight(left, oldVersion.child) - (oldVersion, AnalysisBarrier(newVersion)) - // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -723,6 +783,10 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion @ FlatMapGroupsInPandas(_, _, output, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(output = output.map(_.newInstance()))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) @@ -804,7 +868,7 @@ class Analyzer( private def dedupOuterReferencesInSubquery( plan: LogicalPlan, attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan transformDown { case currentFragment => + plan resolveOperatorsDown { case currentFragment => currentFragment transformExpressions { case OuterReference(a: Attribute) => OuterReference(dedupAttr(a, attrMap)) @@ -815,6 +879,7 @@ class Analyzer( } private def resolve(e: Expression, q: LogicalPlan): Expression = e match { + case f: LambdaFunction if !f.bound => f case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = @@ -830,7 +895,7 @@ class Analyzer( case _ => e.mapChildren(resolve(_, q)) } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. 
@@ -858,11 +923,10 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) - case i @ Intersect(left, right) if !i.duplicateResolved => + case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) - case i @ Except(left, right) if !i.duplicateResolved => - i.copy(right = dedupRight(left, right)) - + case e @ Except(left, right, _) if !e.duplicateResolved => + e.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => @@ -1025,7 +1089,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1081,7 +1145,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1105,12 +1169,12 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions - case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) + if (!s.resolved || s.missingInput.nonEmpty) && child.resolved => val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) val ordering = newOrder.map(_.asInstanceOf[SortOrder]) if (child.output == newChild.output) { @@ -1121,7 +1185,7 @@ class Analyzer( Project(child.output, newSort) } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved => val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) if (child.output == newChild.output) { f.copy(condition = newCond.head) @@ -1132,29 +1196,34 @@ class Analyzer( } } + /** + * This method tries to resolve expressions and find missing attributes recursively. Specially, + * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved + * attributes which are missed from child output. This method tries to find the missing + * attributes out and add into the projection. + */ private def resolveExprsAndAddMissingAttrs( exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - if (exprs.forall(_.resolved)) { - // All given expressions are resolved, no need to continue anymore. + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. 
+ if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { (exprs, plan) } else { plan match { - // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via - // its child. - case barrier: AnalysisBarrier => - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) - (newExprs, AnalysisBarrier(newChild)) - case p: Project => + // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + // Recursively resolving expressions on the child of current plan. val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. + val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) (newExprs, Project(p.projectList ++ missingAttrs, newChild)) case a @ Aggregate(groupExprs, aggExprs, child) => val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { // All the missing attributes are grouping expressions, valid case. (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) @@ -1189,16 +1258,46 @@ class Analyzer( * only performs simple existence check according to the function identifier to quickly identify * undefined functions without triggering relation resolution, which may incur potentially * expensive partition/schema discovery process in some cases. - * + * In order to avoid duplicate external function lookups, the external function identifier will + * be stored in the local hash set externalFunctionNameSet.
* @see [[ResolveFunctions]] * @see https://issues.apache.org/jira/browse/SPARK-19737 */ object LookupFunctions extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { - case f: UnresolvedFunction if !catalog.functionExists(f.name) => - withPosition(f) { - throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) - } + override def apply(plan: LogicalPlan): LogicalPlan = { + val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() + plan.resolveExpressions { + case f: UnresolvedFunction + if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f + case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f + case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) => + externalFunctionNameSet.add(normalizeFuncName(f.name)) + f + case f: UnresolvedFunction => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase), + f.name.funcName) + } + } + } + + def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = { + val funcName = if (conf.caseSensitiveAnalysis) { + name.funcName + } else { + name.funcName.toLowerCase(Locale.ROOT) + } + + val databaseName = name.database match { + case Some(a) => formatDatabaseName(a) + case None => catalog.getCurrentDatabase + } + + FunctionIdentifier(funcName, Some(databaseName)) + } + + protected def formatDatabaseName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } } @@ -1206,7 +1305,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1261,7 +1360,7 @@ class Analyzer( * resolved outer references are wrapped in an [[OuterReference]] */ private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { - plan transformDown { + plan resolveOperatorsDown { case q: LogicalPlan if q.childrenResolved && !q.resolved => q transformExpressions { case u @ UnresolvedAttribute(nameParts) => @@ -1332,18 +1431,33 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved => + case InSubquery(values, l @ ListQuery(_, _, exprId, _)) + if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) - In(value, Seq(expr)) + val subqueryOutput = expr.plan.output + val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery]) + if (values.length != subqueryOutput.length) { + throw new AnalysisException( + s"""Cannot analyze ${resolvedIn.sql}. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. 
+ |#columns in left hand side: ${values.length} + |#columns in right hand side: ${subqueryOutput.length} + |Left side columns: + |[${values.map(_.sql).mkString(", ")}] + |Right side columns: + |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin) + } + resolvedIn } } /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1359,7 +1473,7 @@ class Analyzer( */ object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => // Resolves output attributes if a query has alias names in its subquery: // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) @@ -1382,7 +1496,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1408,9 +1522,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case Filter(cond, AnalysisBarrier(agg: Aggregate)) => - apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause @@ -1468,13 +1580,15 @@ class Analyzer( case ae: AnalysisException => f } - case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => - apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { - val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + // If a sort order is unresolved, containing references not in aggregate, or containing + // `AggregateExpression`, we need to push it down to the underlying aggregate operator.
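The externalFunctionNameSet introduced in LookupFunctions above memoizes identifiers whose external lookup already succeeded, so each distinct function is checked against the external catalog only once per plan. A minimal sketch of that memoization (expensiveExists is a stand-in, not a real catalog API):

import scala.collection.mutable

object LookupMemoDemo extends App {
  var catalogCalls = 0
  def expensiveExists(name: String): Boolean = { catalogCalls += 1; true }

  val seen = mutable.HashSet.empty[String]
  val occurrences = Seq("db.f", "db.f", "db.g", "db.f")
  occurrences.foreach { name =>
    // Only consult the "catalog" for identifiers we have not validated yet.
    if (!seen.contains(name) && expensiveExists(name)) seen += name
  }
  println(catalogCalls)                            // 2: one lookup per distinct function
}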
+ val unresolvedSortOrders = sortOrder.filter { s => + !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + } val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) @@ -1553,11 +1667,13 @@ class Analyzer( expr.find(_.isInstanceOf[Generator]).isDefined } - private def hasNestedGenerator(expr: NamedExpression): Boolean = expr match { - case UnresolvedAlias(_: Generator, _) => false - case Alias(_: Generator, _) => false - case MultiAlias(_: Generator, _) => false - case other => hasGenerator(other) + private def hasNestedGenerator(expr: NamedExpression): Boolean = { + CleanupAliases.trimNonTopLevelAliases(expr) match { + case UnresolvedAlias(_: Generator, _) => false + case Alias(_: Generator, _) => false + case MultiAlias(_: Generator, _) => false + case other => hasGenerator(other) + } } private def trimAlias(expr: NamedExpression): Expression = expr match { @@ -1583,7 +1699,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1598,24 +1714,26 @@ class Analyzer( // Holds the resolved generator, if one exists in the project list. var resolvedGenerator: Generate = null - val newProjectList = projectList.flatMap { - case AliasedGenerator(generator, names, outer) if generator.childrenResolved => - // It's a sanity check, this should not happen as the previous case will throw - // exception earlier. - assert(resolvedGenerator == null, "More than one generator found in SELECT.") - - resolvedGenerator = - Generate( - generator, - unrequiredChildIndex = Nil, - outer = outer, - qualifier = None, - generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), - child) - - resolvedGenerator.generatorOutput - case other => other :: Nil - } + val newProjectList = projectList + .map(CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + .flatMap { + case AliasedGenerator(generator, names, outer) if generator.childrenResolved => + // It's a sanity check, this should not happen as the previous case will throw + // exception earlier. + assert(resolvedGenerator == null, "More than one generator found in SELECT.") + + resolvedGenerator = + Generate( + generator, + unrequiredChildIndex = Nil, + outer = outer, + qualifier = None, + generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), + child) + + resolvedGenerator.generatorOutput + case other => other :: Nil + } if (resolvedGenerator != null) { Project(newProjectList, resolvedGenerator) @@ -1641,7 +1759,7 @@ class Analyzer( * that wrap the [[Generator]]. 
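 *
 * Illustrative example (not part of this patch):
 * {{{
 *   SELECT explode(map('a', 1)) AS (k, v)
 * }}}
 * where the multi-alias supplies the names for the generator's output attributes.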
*/ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -1682,7 +1800,7 @@ class Analyzer( */ object FixNullability extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p: LogicalPlan if p.resolved => val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { @@ -1724,15 +1842,16 @@ class Analyzer( * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions * it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for * all regular expressions. - * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s. - * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts - * it into the plan tree. + * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s + * and [[WindowFunctionType]]s. + * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a + * [[Window]] operator and inserts it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = - projectList.exists(hasWindowFunction) + private def hasWindowFunction(exprs: Seq[Expression]): Boolean = + exprs.exists(hasWindowFunction) - private def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: Expression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -1815,6 +1934,10 @@ class Analyzer( seenWindowAggregates += newAgg WindowExpression(newAgg, spec) + case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => + failAnalysis("It is not allowed to use a window function inside an aggregate " + + "function. Please use the inner window function in a sub-query.") + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => @@ -1882,7 +2005,7 @@ class Analyzer( s"Please file a bug report with this error message, stack trace, and the query.") } else { val spec = distinctWindowSpec.head - (spec.partitionSpec, spec.orderSpec) + (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr)) } }.toSeq @@ -1890,7 +2013,7 @@ class Analyzer( // setting this to the child of the next Window operator. val windowOps = groupedWindowExpressions.foldLeft(child) { - case (last, ((partitionSpec, orderSpec), windowExpressions)) => + case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => Window(windowExpressions, partitionSpec, orderSpec, last) } @@ -1901,7 +2024,10 @@ class Analyzer( // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. 
- def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + + case Filter(condition, _) if hasWindowFunction(condition) => + failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. @@ -1958,7 +2084,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -1997,15 +2123,16 @@ class Analyzer( } /** - * Set the seed for random number generation in Uuid expressions. + * Set the seed for random number generation. */ - object ResolvedUuidExpressions extends Rule[LogicalPlan] { + object ResolveRandomSeed extends Rule[LogicalPlan] { private lazy val random = new Random() - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if p.resolved => p case p => p transformExpressionsUp { case Uuid(None) => Uuid(Some(random.nextLong())) + case Shuffle(child, None) => Shuffle(child, Some(random.nextLong())) } } } @@ -2017,23 +2144,39 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { - case udf @ ScalaUDF(func, _, inputs, _, _, _, _) => - val parameterTypes = ScalaReflection.getParameterTypes(func) - assert(parameterTypes.length == inputs.length) + case udf@ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) => + if (nullableTypes.isEmpty) { + // If no nullability info is available, do nothing. No fields will be specially + // checked for null in the plan. If nullability info is incorrect, the results + // of the UDF could be wrong. + udf + } else { + // Otherwise, add special handling of null for fields that can't accept null. + // The result of operations like this, when passed null, is generally to return null. + assert(nullableTypes.length == inputs.length) - val inputsNullCheck = parameterTypes.zip(inputs) // TODO: skip null handling for not-nullable primitive inputs after we can completely // trust the `nullable` information. - // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } - .filter { case (cls, _) => cls.isPrimitive } - .map { case (_, expr) => IsNull(expr) } - .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) - inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + val inputsNullCheck = nullableTypes.zip(inputs) + .filter { case (nullable, _) => !nullable } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + // Once we add an `If` check above the udf, it is safe to mark those checked inputs + // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning + // branch of `If` will be called if any of these checked inputs is null. 
Thus we can + // prevent this rule from being applied repeatedly. + val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) => + if (nullable) expr else KnownNotNull(expr) + } + inputsNullCheck + .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs))) + .getOrElse(udf) + } } } } @@ -2042,25 +2185,21 @@ class Analyzer( * Check and add proper window frames for all window functions. */ object ResolveWindowFrame extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case logical: LogicalPlan => logical transformExpressions { - case WindowExpression(wf: WindowFunction, - WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) if wf.frame != UnspecifiedFrame && wf.frame != f => - failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") - case WindowExpression(wf: WindowFunction, - s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") + case WindowExpression(wf: WindowFunction, s @ WindowSpecDefinition(_, _, UnspecifiedFrame)) if wf.frame != UnspecifiedFrame => - WindowExpression(wf, s.copy(frameSpecification = wf.frame)) - case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if e.resolved => - val frame = if (o.nonEmpty) { - SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - } else { - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) - } - we.copy(windowSpec = s.copy(frameSpecification = frame)) - } + val frame = if (o.nonEmpty) { + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + } else { + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + } + we.copy(windowSpec = s.copy(frameSpecification = frame)) } } @@ -2068,16 +2207,14 @@ class Analyzer( * Check and add order to [[AggregateWindowFunction]]s. */ object ResolveWindowOrder extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case logical: LogicalPlan => logical transformExpressions { - case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => - failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + - s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " + - s"ORDER BY window_ordering) from table") - case WindowExpression(rank: RankLike, spec) if spec.resolved => - val order = spec.orderSpec.map(_.child) - WindowExpression(rank.withOrder(order), spec) - } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + + s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " + + s"ORDER BY window_ordering) from table") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) } } @@ -2086,8 +2223,8 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. 
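 *
 * Illustrative example (not part of this patch): an inner
 * {{{
 *   SELECT * FROM a JOIN b USING (id)
 * }}}
 * becomes a join on `a.id = b.id` with a Project on top that keeps a single `id` column.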
*/ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case j @ Join(left, right, UsingJoin(joinType, usingCols), _) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => @@ -2097,6 +2234,102 @@ class Analyzer( } } + /** + * Resolves columns of an output table from the data in a logical plan. This rule will: + * + * - Reorder columns when the write is by name + * - Insert safe casts when data types do not match + * - Insert aliases when column names do not match + * - Detect plans that are not compatible with the output table and throw AnalysisException + */ + object ResolveOutputRelation extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case append @ AppendData(table, query, isByName) + if table.resolved && query.resolved && !append.resolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + append.copy(query = projection) + } else { + append + } + } + + def resolveOutputColumns( + tableName: String, + expected: Seq[Attribute], + query: LogicalPlan, + byName: Boolean): LogicalPlan = { + + if (expected.size < query.output.size) { + throw new AnalysisException( + s"""Cannot write to '$tableName', too many data columns: + |Table columns: ${expected.map(c => s"'${c.name}'").mkString(", ")} + |Data columns: ${query.output.map(c => s"'${c.name}'").mkString(", ")}""".stripMargin) + } + + val errors = new mutable.ArrayBuffer[String]() + val resolved: Seq[NamedExpression] = if (byName) { + expected.flatMap { tableAttr => + query.resolveQuoted(tableAttr.name, resolver) match { + case Some(queryExpr) => + checkField(tableAttr, queryExpr, err => errors += err) + case None => + errors += s"Cannot find data for output column '${tableAttr.name}'" + None + } + } + + } else { + if (expected.size > query.output.size) { + throw new AnalysisException( + s"""Cannot write to '$tableName', not enough data columns: + |Table columns: ${expected.map(c => s"'${c.name}'").mkString(", ")} + |Data columns: ${query.output.map(c => s"'${c.name}'").mkString(", ")}""" + .stripMargin) + } + + query.output.zip(expected).flatMap { + case (queryExpr, tableAttr) => + checkField(tableAttr, queryExpr, err => errors += err) + } + } + + if (errors.nonEmpty) { + throw new AnalysisException( + s"Cannot write incompatible data to table '$tableName':\n- ${errors.mkString("\n- ")}") + } + + Project(resolved, query) + } + + private def checkField( + tableAttr: Attribute, + queryExpr: NamedExpression, + addError: String => Unit): Option[NamedExpression] = { + + // run the type check first to ensure type errors are present + val canWrite = DataType.canWrite( + queryExpr.dataType, tableAttr.dataType, resolver, tableAttr.name, addError) + + if (queryExpr.nullable && !tableAttr.nullable) { + addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'") + None + + } else if (!canWrite) { + None + + } else { + // always add an UpCast. it will be removed in the optimizer if it is unnecessary. 
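+ // Illustrative example (not part of this patch): writing an IntegerType query column into a
+ // LongType table column yields Alias(UpCast(queryExpr, LongType), tableAttr.name); when the
+ // types already match, the UpCast is a no-op that the optimizer drops.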
+ Some(Alias( + UpCast(queryExpr, tableAttr.dataType, Seq()), tableAttr.name + )( + explicitMetadata = Option(tableAttr.metadata) + )) + } + } + } + private def commonNaturalJoinProcessing( left: LogicalPlan, right: LogicalPlan, @@ -2151,7 +2384,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2176,7 +2409,7 @@ class Analyzer( } expr case other => - throw new AnalysisException("need an array field but got " + other.simpleString) + throw new AnalysisException("need an array field but got " + other.catalogString) } } validateNestedTupleFields(result) @@ -2185,8 +2418,8 @@ class Analyzer( } private def fail(schema: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " + - "but failed as the number of fields does not line up.") + throw new AnalysisException(s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}" + + ", but failed as the number of fields does not line up.") } /** @@ -2237,7 +2470,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2265,13 +2498,13 @@ class Analyzer( case e => e.sql } throw new AnalysisException(s"Cannot up cast $fromStr from " + - s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + s"${from.dataType.catalogString} to ${to.catalogString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2293,8 +2526,12 @@ class Analyzer( * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child) => child + // This is also called in the beginning of the optimization phase, and as a result + // is using transformUp rather than resolveOperators. + def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan transformUp { + case SubqueryAlias(_, child) => child + } } } @@ -2302,7 +2539,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { * Removes [[Union]] operators from the plan if it just has one child. 
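 *
 * Illustrative example (not part of this patch): a single-child {{{ Union(child :: Nil) }}}
 * produced during parsing is simply replaced by {{{ child }}} itself.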
*/ object EliminateUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case Union(children) if children.size == 1 => children.head } } @@ -2318,6 +2555,7 @@ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { e.transformDown { case Alias(child, _) => child + case MultiAlias(child, _) => child } } @@ -2327,10 +2565,12 @@ object CleanupAliases extends Rule[LogicalPlan] { exprId = a.exprId, qualifier = a.qualifier, explicitMetadata = Some(a.metadata)) + case a: MultiAlias => + a.copy(child = trimAliases(a.child)) case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2340,7 +2580,7 @@ object CleanupAliases extends Rule[LogicalPlan] { val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) Aggregate(grouping.map(trimAliases), cleanedAggs, child) - case w @ Window(windowExprs, partitionSpec, orderSpec, child) => + case Window(windowExprs, partitionSpec, orderSpec, child) => val cleanedWindowExprs = windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) Window(cleanedWindowExprs, partitionSpec.map(trimAliases), @@ -2359,19 +2599,12 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** Remove the barrier nodes of analysis */ -object EliminateBarriers extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case AnalysisBarrier(child) => child - } -} - /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. */ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case EventTimeWatermark(_, _, child) if !child.isStreaming => child } } @@ -2416,7 +2649,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = @@ -2504,7 +2737,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
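 *
 * Illustrative example (not part of this patch): {{{ struct(col("a"), col("b")) }}} initially
 * carries [[NamePlaceholder]]s for its field names; once the children are resolved, this rule
 * replaces each placeholder with the corresponding column name.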
*/ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: CreateNamedStruct if !e.resolved => val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => @@ -2556,7 +2789,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { private def updateOuterReferenceInSubquery( plan: LogicalPlan, refExprs: Seq[Expression]): LogicalPlan = { - plan transformAllExpressions { case e => + plan resolveExpressions { case e => val outerAlias = refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) outerAlias match { @@ -2567,7 +2800,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { + plan resolveOperators { case f @ Filter(_, a: Aggregate) if f.resolved => f transformExpressions { case s: SubqueryExpression if s.children.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad82..6a91d556b2f3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ @@ -66,11 +67,15 @@ trait CheckAnalysis extends PredicateHelper { limitExpr.sql) case e if e.dataType != IntegerType => failAnalysis( s"The limit expression must be integer type, but got " + - e.dataType.simpleString) - case e if e.eval().asInstanceOf[Int] < 0 => failAnalysis( - "The limit expression must be equal to or greater than 0, but got " + - e.eval().asInstanceOf[Int]) - case e => // OK + e.dataType.catalogString) + case e => + e.eval() match { + case null => failAnalysis( + s"The evaluated limit expression must not be null, but got ${limitExpr.sql}") + case v: Int if v < 0 => failAnalysis( + s"The limit expression must be equal to or greater than 0, but got $v") + case _ => // OK + } } } @@ -78,10 +83,27 @@ trait CheckAnalysis extends PredicateHelper { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + + case p if p.analyzed => // Skip already analyzed sub-plans + case u: UnresolvedRelation => u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") case operator: LogicalPlan => + // Check argument data types of higher-order functions downwards first. + // If the arguments of the higher-order functions are resolved but the type check fails, + // the argument functions will not get resolved, but we should report the argument type + // check failure instead of claiming the argument functions are unresolved. 
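+ // Illustrative example (not part of this patch): `transform(1, x -> x + 1)` should report an
+ // argument data type mismatch for the non-array first argument instead of claiming the lambda
+ // function is unresolved.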
+ operator transformExpressionsDown { + case hof: HigherOrderFunction + if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure => + hof.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + hof.failAnalysis( + s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message") + } + } + operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.qualifiedName).mkString(", ") @@ -95,8 +117,8 @@ trait CheckAnalysis extends PredicateHelper { } case c: Cast if !c.resolved => - failAnalysis( - s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + failAnalysis(s"invalid cast from ${c.child.dataType.catalogString} to " + + c.dataType.catalogString) case g: Grouping => failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup") @@ -112,12 +134,19 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") + case _ @ WindowExpression(_: PythonUDF, + WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame)) + if !frame.isUnbounded => + failAnalysis("Only unbounded window frame is supported with Pandas UDFs.") + case w @ WindowExpression(e, s) => // Only allow window functions with an aggregate expression or an offset window - // function. + // function or a Pandas window UDF. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => w + case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } @@ -136,12 +165,12 @@ trait CheckAnalysis extends PredicateHelper { case _ => failAnalysis( s"Event time must be defined on a window or a timestamp, but " + - s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.simpleString}") + s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.catalogString}") } case f: Filter if f.condition.dataType != BooleanType => failAnalysis( s"filter expression '${f.condition.sql}' " + - s"of type ${f.condition.dataType.simpleString} is not a boolean.") + s"of type ${f.condition.dataType.catalogString} is not a boolean.") case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + @@ -150,11 +179,11 @@ trait CheckAnalysis extends PredicateHelper { case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.sql}' " + - s"of type ${condition.dataType.simpleString} is not a boolean.") + s"of type ${condition.dataType.catalogString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { - expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) } def checkValidAggregateExpression(expr: Expression): Unit = expr match { @@ -211,7 +240,7 @@ trait CheckAnalysis extends PredicateHelper { if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( s"expression ${expr.sql} cannot be used as a grouping expression " + - s"because its data type ${expr.dataType.simpleString} is not an orderable " + + s"because its data type ${expr.dataType.catalogString} is not an orderable " + s"data type.") } @@ -231,7 +260,7 @@ trait CheckAnalysis extends PredicateHelper { orders.foreach { 
order => if (!RowOrdering.isOrderable(order.dataType)) { failAnalysis( - s"sorting is not supported for columns of type ${order.dataType.simpleString}") + s"sorting is not supported for columns of type ${order.dataType.catalogString}") } } @@ -334,7 +363,7 @@ trait CheckAnalysis extends PredicateHelper { val mapCol = mapColumnInSetOperation(o).get failAnalysis("Cannot have map type columns in DataFrame which calls " + s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " + - "is " + mapCol.dataType.simpleString) + "is " + mapCol.dataType.catalogString) case o if o.expressions.exists(!_.deterministic) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && @@ -356,10 +385,11 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { - case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } + + plan.setAnalyzed() } /** @@ -523,9 +553,8 @@ trait CheckAnalysis extends PredicateHelper { var foundNonEqualCorrelatedPred: Boolean = false - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - BooleanSimplification(sub).foreachUp { + // Simplify the predicates before validating any unsupported correlation patterns in the plan. + AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: // 1. Operators that are allowed anywhere in a correlated subquery, and, @@ -627,6 +656,6 @@ trait CheckAnalysis extends PredicateHelper { // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - } + }} } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index ab63131b07573..e511f8064e28a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -82,14 +82,14 @@ object DecimalPrecision extends TypeCoercionRule { PromotePrecision(Cast(e, dataType)) } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) } /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ - private val decimalAndDecimal: PartialFunction[Expression, Expression] = { + private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e @@ -286,7 +286,7 @@ object DecimalPrecision extends TypeCoercionRule { // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. // If we use the default precision and scale for the integer type, 2 is considered a // DECIMAL(10, 0). 
According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), - // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // which is out of range and therefore it will become DECIMAL(38, 7), leading to // potentially loosing 11 digits of the fractional part. Using only the precision needed // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would // become DECIMAL(38, 16), safely having a much lower precision loss. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c41f16c61d7a2..77860e1584f42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -299,6 +299,15 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[RegrCount]("regr_count"), + expression[RegrSXX]("regr_sxx"), + expression[RegrSYY]("regr_syy"), + expression[RegrAvgX]("regr_avgx"), + expression[RegrAvgY]("regr_avgy"), + expression[RegrSXY]("regr_sxy"), + expression[RegrSlope]("regr_slope"), + expression[RegrR2]("regr_r2"), + expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), @@ -401,18 +410,47 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArraysOverlap]("arrays_overlap"), + expression[ArrayIntersect]("array_intersect"), + expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), + expression[ArraySort]("array_sort"), + expression[ArrayExcept]("array_except"), + expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), + expression[MapFromArrays]("map_from_arrays"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), + expression[MapEntries]("map_entries"), + expression[MapFromEntries]("map_from_entries"), + expression[MapConcat]("map_concat"), expression[Size]("size"), + expression[Slice]("slice"), + expression[Size]("cardinality"), + expression[ArraysZip]("arrays_zip"), expression[SortArray]("sort_array"), + expression[Shuffle]("shuffle"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), expression[Concat]("concat"), + expression[Flatten]("flatten"), + expression[Sequence]("sequence"), + expression[ArrayRepeat]("array_repeat"), + expression[ArrayRemove]("array_remove"), + expression[ArrayDistinct]("array_distinct"), + expression[ArrayTransform]("transform"), + expression[MapFilter]("map_filter"), + expression[ArrayFilter]("filter"), + expression[ArrayExists]("exists"), + expression[ArrayAggregate]("aggregate"), + expression[TransformValues]("transform_values"), + expression[TransformKeys]("transform_keys"), + expression[MapZipWith]("map_zip_with"), + expression[ZipWith]("zip_with"), + CreateStruct.registryEntry, // misc functions @@ -474,6 +512,7 @@ object FunctionRegistry { // json expression[StructsToJson]("to_json"), expression[JsonToStructs]("from_json"), + expression[SchemaOfJson]("schema_of_json"), // cast expression[Cast]("cast"), @@ -531,7 +570,9 @@ object FunctionRegistry { // Otherwise, find a constructor method that 
matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val validParametersCount = constructors + .filter(_.getParameterTypes.forall(_ == classOf[Expression])) + .map(_.getParameterCount).distinct.sorted val expectedNumberOfParameters = if (validParametersCount.length == 1) { validParametersCount.head.toString } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala new file mode 100644 index 0000000000000..ad201f947b671 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +trait NamedRelation extends LogicalPlan { + def name: String +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f068bce3e9b69..dbd4ed845e329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.IntegerLiteral import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -85,7 +86,7 @@ object ResolveHints { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. @@ -102,12 +103,38 @@ object ResolveHints { } } + /** + * COALESCE Hint accepts name "COALESCE" and "REPARTITION". + * Its parameter includes a partition number. 
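+ *
+ * For example (illustrative, not part of this patch), `COALESCE(2)` becomes
+ * `Repartition(2, shuffle = false, child)`, while `REPARTITION(2)` produces the same node with
+ * `shuffle = true`, mirroring `Dataset.coalesce` and `Dataset.repartition`.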
+ */ + object ResolveCoalesceHints extends Rule[LogicalPlan] { + private val COALESCE_HINT_NAMES = Set("COALESCE", "REPARTITION") + + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case h: UnresolvedHint if COALESCE_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => + val hintName = h.name.toUpperCase(Locale.ROOT) + val shuffle = hintName match { + case "REPARTITION" => true + case "COALESCE" => false + } + val numPartitions = h.parameters match { + case Seq(IntegerLiteral(numPartitions)) => + numPartitions + case Seq(numPartitions: Int) => + numPartitions + case _ => + throw new AnalysisException(s"$hintName Hint expects a partition number as parameter") + } + Repartition(numPartitions, shuffle, h.child) + } + } + /** * Removes all the hints, used to remove invalid hints provided by the user. * This must be executed after all the other hint rules are executed. */ object RemoveAllHints extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case h: UnresolvedHint => h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index f2df3e132629f..4edfe507a7580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) validateInputEvaluable(table) @@ -103,7 +103,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas castedExpr.eval() } catch { case NonFatal(ex) => - table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex) } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index a214e59302cd9..983e4b0e901cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.rules._ @@ -68,9 +69,11 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { : (ArgumentList, Seq[Any] => LogicalPlan) = { (ArgumentList(args: _*), pf orElse { - case args => - throw new IllegalArgumentException( - "Invalid arguments for resolved function: " + args.mkString(", ")) + case arguments => + // This is 
caught again by the apply function and rethrow with richer information about + // position, etc, for a better error message. + throw new AnalysisException( + "Invalid arguments for resolved function: " + arguments.mkString(", ")) }) } @@ -103,24 +106,37 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => + // The whole resolution is somewhat difficult to understand here due to too much abstractions. + // We should probably rewrite the following at some point. Reynold was just here to improve + // error messages and didn't have time to do a proper rewrite. val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => + + def failAnalysis(): Nothing = { + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: ($argTypes)""".stripMargin) + } + val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { case Some(casted) => - Some(resolver(casted.map(_.eval()))) + try { + Some(resolver(casted.map(_.eval()))) + } catch { + case e: AnalysisException => + failAnalysis() + } case _ => None } } resolved.headOption.getOrElse { - val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") - u.failAnalysis( - s"""error: table-valued function ${u.functionName} with alternatives: - |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} - |cannot be applied to: (${argTypes})""".stripMargin) + failAnalysis() } case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index f9fd0df9e4010..860d20f897690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cfcbd8db559a3..288b6358fbff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,12 +54,13 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + MapZipWithCoercion :: EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: Division :: - 
ImplicitTypeCasts :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -102,17 +103,7 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - - case _ => None + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } /** Promotes all the way to StringType. */ @@ -158,6 +149,60 @@ object TypeCoercion { case (l, r) => None } + private def findTypeForComplex( + t1: DataType, + t2: DataType, + findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findTypeFunc(et1, et2).map { et => + ArrayType(et, containsNull1 || containsNull2 || + Cast.forceNullable(et1, et) || Cast.forceNullable(et2, et)) + } + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findTypeFunc(kt1, kt2) + .filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) } + .flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2 || + Cast.forceNullable(vt1, vt) || Cast.forceNullable(vt2, vt)) + } + } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + findTypeFunc(field1.dataType, field2.dataType).map { dt => + struct.add(field1.name, dt, field1.nullable || field2.nullable || + Cast.forceNullable(field1.dataType, dt) || Cast.forceNullable(field2.dataType, dt)) + } + case _ => None + } + case _ => None + } + + /** + * The method finds a common type for data types that differ only in nullable, containsNull + * and valueContainsNull flags. If the input types are too different, None is returned. + */ + def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { + if (t1 == t2) { + Some(t1) + } else { + findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags) + } + } + + def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = { + if (types.isEmpty) { + None + } else { + types.tail.foldLeft[Option[DataType]](Some(types.head)) { + case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2) + case _ => None + } + } + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). 
* @@ -168,11 +213,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } /** @@ -208,12 +249,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) - .map(ArrayType(_, containsNull1 || containsNull2)) - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { @@ -242,8 +278,25 @@ object TypeCoercion { } } - private def haveSameType(exprs: Seq[Expression]): Boolean = - exprs.map(_.dataType).distinct.length == 1 + /** + * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull. + */ + def haveSameType(types: Seq[DataType]): Boolean = { + if (types.size <= 1) { + true + } else { + val head = types.head + types.tail.forall(_.sameType(head)) + } + } + + private def castIfNotSameType(expr: Expression, dt: DataType): Expression = { + if (!expr.dataType.sameType(dt)) { + Cast(expr, dt) + } else { + expr + } + } /** * Widens numeric types and converts strings to numbers when appropriate. @@ -273,12 +326,18 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case s @ SetOperation(left, right) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + case s @ Except(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + Except(newChildren.head, newChildren.last, isAll) + + case s @ Intersect(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) - s.makeCopy(Array(newChildren.head, newChildren.last)) + Intersect(newChildren.head, newChildren.last, isAll) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => @@ -346,7 +405,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -398,27 +457,16 @@ object TypeCoercion { * Analysis Exception will be raised at the type checking phase. */ case class InConversion(conf: SQLConf) extends TypeCoercionRule { - private def flattenExpr(expr: Expression): Seq[Expression] = { - expr match { - // Multi columns in IN clause is represented as a CreateNamedStruct. - // flatten the named struct to get the list of expressions. 
- case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ In(a, Seq(ListQuery(sub, children, exprId, _))) - if !i.resolved && flattenExpr(a).length == sub.output.length => - // LHS is the value expression of IN subquery. - val lhs = flattenExpr(a) - + case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _)) + if !i.resolved && lhs.length == sub.output.length => + // LHS is the value expressions of IN subquery. // RHS is the subquery output. val rhs = sub.output @@ -434,20 +482,13 @@ object TypeCoercion { case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } - val castedLhs = lhs.zip(commonTypes).map { + val newLhs = lhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Cast(e, dt) case (e, _) => e } - // Before constructing the In expression, wrap the multi values in LHS - // in a CreatedNamedStruct. - val newLhs = castedLhs match { - case Seq(lhs) => lhs - case _ => CreateStruct(castedLhs) - } - val newSub = Project(castedRhs, sub) - In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output))) + InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output)) } else { i } @@ -467,7 +508,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -508,46 +549,63 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends TypeCoercionRule { + override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e - case a @ CreateArray(children) if !haveSameType(children) => + case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType))) case None => a } case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && - !haveSameType(children) => + !haveSameType(c.inputTypesForMerging) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType))) case None => c } + case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) && + ArrayType.acceptsType(arr.dataType) => + val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull + ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match { + case Some(castedArr) => ArrayJoin(castedArr, d, nr) + case None => aj + } + + case s @ Sequence(_, _, _, timeZoneId) + if !haveSameType(s.coercibleChildren.map(_.dataType)) => + val types = s.coercibleChildren.map(_.dataType) + findWiderCommonType(types) match { + case Some(widerDataType) => s.castChildrenTo(widerDataType) + case None => s + } + + case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && + !haveSameType(m.inputTypesForMerging) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType))) + case None => m + } + case m @ CreateMap(children) if m.keys.length == m.values.length && - (!haveSameType(m.keys) || !haveSameType(m.values)) => - val newKeys = if (haveSameType(m.keys)) { - m.keys - } else { - val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) - case None => m.keys - } + (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) => + val keyTypes = m.keys.map(_.dataType) + val newKeys = findWiderCommonType(keyTypes) match { + case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType)) + case None => m.keys } - val newValues = if (haveSameType(m.values)) { - m.values - } else { - val types = m.values.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) - case None => m.values - } + val valueTypes = m.values.map(_.dataType) + val newValues = findWiderCommonType(valueTypes) match { + case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType)) + case None => m.values } CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) @@ -570,27 +628,27 @@ object TypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. 
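// Illustrative example (not part of this patch): `coalesce(int_col, double_col)` casts the
// integer column to DoubleType so that every branch shares the wider common type.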
- case c @ Coalesce(es) if !haveSameType(es) => + case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) => val types = es.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) + case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) case None => c } // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if // we need to truncate, but we should not promote one side to string if the other side is // string.g - case g @ Greatest(children) if !haveSameType(children) => + case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType))) case None => g } - case l @ Least(children) if !haveSameType(children) => + case l @ Least(children) if !haveSameType(l.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType))) case None => l } @@ -608,7 +666,7 @@ object TypeCoercion { */ object Division extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -631,28 +689,15 @@ object TypeCoercion { */ object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => - var changed = false val newBranches = c.branches.map { case (condition, value) => - if (value.dataType.sameType(commonType)) { - (condition, value) - } else { - changed = true - (condition, Cast(value, commonType)) - } + (condition, castIfNotSameType(value, commonType)) } - val newElseValue = c.elseValue.map { value => - if (value.dataType.sameType(commonType)) { - value - } else { - changed = true - Cast(value, commonType) - } - } - if (changed) CaseWhen(newBranches, newElseValue) else c + val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) + CaseWhen(newBranches, newElseValue) }.getOrElse(c) } } @@ -662,13 +707,13 @@ object TypeCoercion { */ object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. 
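// Illustrative sketch, not part of the patch, of the rules above as seen from the SQL
// API (assumes a SparkSession `spark`): Division promotes integral operands to DOUBLE,
// and Coalesce/Greatest/Least/CaseWhen/If cast their branches to a wider common type.
spark.sql("SELECT 1 / 2").show()                                                    // 0.5, not 0
spark.sql("SELECT coalesce(1, CAST(2 AS BIGINT))").printSchema()                    // bigint
spark.sql("SELECT CASE WHEN true THEN 1 ELSE CAST(2 AS BIGINT) END").printSchema()  // bigint
// Greatest/Least deliberately skip string promotion, so greatest(1, 'a') stays an error.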
- case i @ If(pred, left, right) if left.dataType != right.dataType => + case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + val newLeft = castIfNotSameType(left, widestType) + val newRight = castIfNotSameType(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => @@ -682,7 +727,7 @@ object TypeCoercion { * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. */ object StackCoercion extends TypeCoercionRule { - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. @@ -702,20 +747,46 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or empty children - case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ Concat(children) if conf.concatBinaryAsString || + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - val newChildren = c.children.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - c.copy(children = newChildren) + val newChildren = c.children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) + } } } } + /** + * Coerces key types of two different [[MapType]] arguments of the [[MapZipWith]] expression + * to a common type. + */ + object MapZipWithCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + // Lambda function isn't resolved when the rule is executed. + case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && + MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) => + findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { + case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) && + !Cast.forceNullable(m.rightKeyType, finalKeyType) => + val newLeft = castIfNotSameType( + left, + MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull)) + val newRight = castIfNotSameType( + right, + MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull)) + MapZipWith(newLeft, newRight, function) + case _ => m + } + } + } + /** * Coerces the types of [[Elt]] children to expected ones. 
* @@ -724,22 +795,24 @@ object TypeCoercion { */ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or not enough children - case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c - case c @ Elt(children) => - val index = children.head - val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) - val newInputs = if (conf.eltOutputAsString || + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { - children.tail.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail } - } else { - children.tail - } - c.copy(children = newIndex +: newInputs) + c.copy(children = newIndex +: newInputs) + } } } } @@ -752,7 +825,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -768,12 +841,33 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + + private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) + override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp + // string is in a specific timezone, so the string itself should not contain timezone. + // TODO: We should move the type coercion logic to expressions instead of a central + // place to put all the rules. + case e: FromUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + + case e: ToUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { @@ -790,7 +884,7 @@ object TypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. 
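// Illustrative sketch, not part of the patch: with the special case above, a STRING
// first argument to from_utc_timestamp/to_utc_timestamp is still accepted. It is either
// cast to TIMESTAMP or, when SQLConf.REJECT_TIMEZONE_IN_STRING is enabled, wrapped in
// StringToTimestampWithoutTimezone so that a string carrying its own timezone is
// rejected instead of being silently reinterpreted. Assumes a SparkSession `spark`.
spark.sql("SELECT from_utc_timestamp('2018-07-01 00:00:00', 'Asia/Seoul')").show()
spark.sql("SELECT to_utc_timestamp('2018-07-01 09:00:00', 'Asia/Seoul')").show()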
- implicitCast(in, expected).getOrElse(in) + ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -806,6 +900,9 @@ object TypeCoercion { } e.withNewChildren(children) } + } + + object ImplicitTypeCasts { /** * Given an expected data type, try to cast the expression and return the cast expression. @@ -888,7 +985,7 @@ object TypeCoercion { */ object WindowFrameCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -926,7 +1023,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { protected def coerceTypes(plan: LogicalPlan): LogicalPlan - private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index ff9d6d7a7dded..cff4cee09427f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** @@ -305,17 +306,19 @@ object UnsupportedOperationChecker { case u: Union if u.children.map(_.isStreaming).distinct.size == 2 => throwError("Union between streaming and batch DataFrames/Datasets is not supported") - case Except(left, right) if right.isStreaming => + case Except(left, right, _) if right.isStreaming => throwError("Except on a streaming DataFrame/Dataset on the right is not supported") - case Intersect(left, right) if left.isStreaming && right.isStreaming => + case Intersect(left, right, _) if left.isStreaming && right.isStreaming => throwError("Intersect between two streaming DataFrames/Datasets is not supported") case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => - throwError("Limits are not supported on streaming DataFrames/Datasets") + case GlobalLimit(_, _) | LocalLimit(_, _) + if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update => + throwError("Limits are not supported on streaming DataFrames/Datasets in Update " + + "output mode") case Sort(_, _, _) if !containsCompleteData(subPlan) => throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + @@ -345,8 +348,20 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject | _: 
SubqueryAlias) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | + _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => + case Repartition(1, false, _) => + case node: Aggregate => + val aboveSinglePartitionCoalesce = node.find { + case Repartition(1, false, _) => true + case _ => false + }.isDefined + + if (!aboveSinglePartitionCoalesce) { + throwError(s"In continuous processing mode, coalesce(1) must be called before " + + s"aggregate operation ${node.nodeName}.") + } case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala new file mode 100644 index 0000000000000..dd08190e1e8a3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * Resolve a higher order functions from the catalog. This is different from regular function + * resolution because lambda functions can only be resolved after the function has been resolved; + * so we need to resolve higher order function when all children are either resolved or a lambda + * function. + */ +case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { + case u @ UnresolvedFunction(fn, children, false) + if hasLambdaAndResolvedArguments(children) => + withPosition(u) { + catalog.lookupFunction(fn, children) match { + case func: HigherOrderFunction => func + case other => other.failAnalysis( + "A lambda function should only be used in a higher order function. However, " + + s"its class is ${other.getClass.getCanonicalName}, which is not a " + + s"higher order function.") + } + } + } + + /** + * Check if the arguments of a function are either resolved or a lambda function. + */ + private def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = { + val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction]) + lambdas.nonEmpty && others.forall(_.resolved) + } +} + +/** + * Resolve the lambda variables exposed by a higher order functions. 
+ * + * This rule works in two steps: + * [1]. Bind the anonymous variables exposed by the higher order function to the lambda function's + * arguments; this creates named and typed lambda variables. The argument names are checked + * for duplicates and the number of arguments are checked during this step. + * [2]. Resolve the used lambda variables used in the lambda function's function expression tree. + * Note that we allow the use of variables from outside the current lambda, this can either + * be a lambda function defined in an outer scope, or a attribute in produced by the plan's + * child. If names are duplicate, the name defined in the most inner scope is used. + */ +case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { + + type LambdaVariableMap = Map[String, NamedExpression] + + private val canonicalizer = { + if (!conf.caseSensitiveAnalysis) { + s: String => s.toLowerCase + } else { + s: String => s + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperators { + case q: LogicalPlan => + q.mapExpressions(resolve(_, Map.empty)) + } + } + + /** + * Create a bound lambda function by binding the arguments of a lambda function to the given + * partial arguments (dataType and nullability only). If the expression happens to be an already + * bound lambda function then we assume it has been bound to the correct arguments and do + * nothing. This function will produce a lambda function with hidden arguments when it is passed + * an arbitrary expression. + */ + private def createLambda( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction if f.bound => f + + case LambdaFunction(function, names, _) => + if (names.size != argInfo.size) { + e.failAnalysis( + s"The number of lambda function arguments '${names.size}' does not " + + "match the number of arguments expected by the higher order function " + + s"'${argInfo.size}'.") + } + + if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) { + e.failAnalysis( + "Lambda function arguments should not have names that are semantically the same.") + } + + val arguments = argInfo.zip(names).map { + case ((dataType, nullable), ne) => + NamedLambdaVariable(ne.name, dataType, nullable) + } + LambdaFunction(function, arguments) + + case _ => + // This expression does not consume any of the lambda's arguments (it is independent). We do + // create a lambda function with default parameters because this is expected by the higher + // order function. Note that we hide the lambda variables produced by this function in order + // to prevent accidental naming collisions. + val arguments = argInfo.zipWithIndex.map { + case ((dataType, nullable), i) => + NamedLambdaVariable(s"col$i", dataType, nullable) + } + LambdaFunction(e, arguments, hidden = true) + } + + /** + * Resolve lambda variables in the expression subtree, using the passed lambda variable registry. + */ + private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match { + case _ if e.resolved => e + + case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess => + h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap)) + + case l: LambdaFunction if !l.bound => + // Do not resolve an unbound lambda function. If we see such a lambda function this means + // that either the higher order function has yet to be resolved, or that we are seeing + // dangling lambda function. 
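// Illustrative sketch, not part of the patch: ResolveLambdaVariables is what makes the
// lambda arguments of SQL higher-order functions resolvable. The names `x` and `y` are
// arbitrary; an outer lambda variable remains visible inside a nested lambda, and the
// innermost binding wins on a name clash. Assumes a SparkSession `spark`.
spark.sql("SELECT transform(array(1, 2, 3), x -> x + 1)").show()
spark.sql("SELECT transform(array(1, 2), x -> transform(array(10, 20), y -> x + y))").show()
// Passing a lambda to a regular function, e.g. upper(x -> x), is rejected by
// ResolveHigherOrderFunctions with "A lambda function should only be used in a higher
// order function".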
+ l + + case l: LambdaFunction if !l.hidden => + val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap + l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap)) + + case u @ UnresolvedAttribute(name +: nestedFields) => + parentLambdaMap.get(canonicalizer(name)) match { + case Some(lambda) => + nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) => + ExtractValue(expr, Literal(fieldName), conf.resolver) + } + case None => u + } + + case _ => + e.mapChildren(resolve(_, parentLambdaMap)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 7731336d247db..354a3fa0602a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -41,6 +41,11 @@ package object analysis { def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } + + /** Fails the analysis at the point where a specific tree node was parsed. */ + def failAnalysis(msg: String, cause: Throwable): Nothing = { + throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause)) + } } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index af1f9165b0044..a27aa845bf0ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformAllExpressions(transformTimeZoneExprs) + plan.resolveExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 71e23175168e2..c1ec736c32ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -104,12 +104,12 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override lazy val resolved = false override def newInstance(): UnresolvedAttribute = this override def withNullability(newNullability: Boolean): UnresolvedAttribute = this - override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this + override def withQualifier(newQualifier: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) 
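// Illustrative sketch, not part of the patch: because Attribute.qualifier is now a
// Seq[String] instead of an Option[String], stars and column references can carry a
// database-and-table qualifier. Database/table names are made up; assumes a
// SparkSession `spark` backed by the default catalog.
spark.sql("CREATE DATABASE IF NOT EXISTS db1")
spark.sql("CREATE TABLE IF NOT EXISTS db1.t1 (a INT, b STRING) USING parquet")
spark.sql("SELECT db1.t1.* FROM db1.t1").show()   // full database.table qualifier
spark.sql("SELECT t1.* FROM db1.t1").show()       // table-only qualifier still matches
spark.sql("SELECT a.* FROM db1.t1 AS a").show()   // alias qualifier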
override def withMetadata(newMetadata: Metadata): Attribute = this @@ -240,7 +240,7 @@ abstract class Star extends LeafExpression with NamedExpression { override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") override lazy val resolved = false @@ -262,17 +262,46 @@ abstract class Star extends LeafExpression with NamedExpression { */ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable { - override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = { + /** + * Returns true if the nameParts match the qualifier of the attribute + * + * There are two checks: i) Check if the nameParts match the qualifier fully. + * E.g. SELECT db.t1.* FROM db1.t1 In this case, the nameParts is Seq("db1", "t1") and + * qualifier of the attribute is Seq("db1","t1") + * ii) If (i) is not true, then check if nameParts is only a single element and it + * matches the table portion of the qualifier + * + * E.g. SELECT t1.* FROM db1.t1 In this case nameParts is Seq("t1") and + * qualifier is Seq("db1","t1") + * SELECT a.* FROM db1.t1 AS a + * In this case nameParts is Seq("a") and qualifier for + * attribute is Seq("a") + */ + private def matchedQualifier( + attribute: Attribute, + nameParts: Seq[String], + resolver: Resolver): Boolean = { + val qualifierList = attribute.qualifier + + val matched = nameParts.corresponds(qualifierList)(resolver) || { + // check if it matches the table portion of the qualifier + if (nameParts.length == 1 && qualifierList.nonEmpty) { + resolver(nameParts.head, qualifierList.last) + } else { + false + } + } + matched + } + + override def expand( + input: LogicalPlan, + resolver: Resolver): Seq[NamedExpression] = { // If there is no table specified, use all input attributes. if (target.isEmpty) return input.output - val expandedAttributes = - if (target.get.size == 1) { - // If there is a table, pick out attributes that are part of this table. - input.output.filter(_.qualifier.exists(resolver(_, target.get.head))) - } else { - List() - } + val expandedAttributes = input.output.filter(matchedQualifier(_, target.get, resolver)) + if (expandedAttributes.nonEmpty) return expandedAttributes // Try to resolve it as a struct expansion. 
If there is a conflict and both are possible, @@ -316,8 +345,8 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens // If there is no table specified, use all input attributes that match expr case None => input.output.filter(_.name.matches(pattern)) // If there is a table, pick out attributes that are part of this table that match expr - case Some(t) => input.output.filter(_.qualifier.exists(resolver(_, t))) - .filter(_.name.matches(pattern)) + case Some(t) => input.output.filter(a => a.qualifier.nonEmpty && + resolver(a.qualifier.last, t)).filter(_.name.matches(pattern)) } } @@ -345,7 +374,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") @@ -403,7 +432,7 @@ case class UnresolvedAlias( extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override def dataType: DataType = throw new UnresolvedException(this, "dataType") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 20216087b0158..af74693000c44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames @@ -76,7 +76,8 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupp // Will throw an AnalysisException if the cast can't perform or might truncate. 
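// Illustrative sketch, not part of the patch, of when AliasViewChild raises the reworded
// "may truncate" error above: the view records column `i` as INT, and after the
// underlying table is recreated with a wider BIGINT column, reading the view would need
// a narrowing cast. Names are made up; assumes a SparkSession `spark`.
spark.sql("CREATE TABLE src (i INT) USING parquet")
spark.sql("CREATE VIEW v AS SELECT i FROM src")
spark.sql("DROP TABLE src")
spark.sql("CREATE TABLE src (i BIGINT) USING parquet")
// Expected to fail with "Cannot up cast ... from bigint to int as it may truncate":
// spark.sql("SELECT * FROM v").show()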
if (Cast.mayTruncate(originAttr.dataType, attr.dataType)) { throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " + - s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n") + s"${originAttr.dataType.catalogString} to ${attr.dataType.catalogString} as it " + + s"may truncate\n") } else { Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 45b4f013620c1..1a145c24d78cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchPartitionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -31,10 +30,13 @@ import org.apache.spark.util.ListenerBus * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog - extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { +trait ExternalCatalog { import CatalogTypes.TablePartitionSpec + // -------------------------------------------------------------------------- + // Utils + // -------------------------------------------------------------------------- + protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { throw new NoSuchDatabaseException(db) @@ -63,22 +65,9 @@ abstract class ExternalCatalog // Databases // -------------------------------------------------------------------------- - final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val db = dbDefinition.name - postToAll(CreateDatabasePreEvent(db)) - doCreateDatabase(dbDefinition, ignoreIfExists) - postToAll(CreateDatabaseEvent(db)) - } + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - - final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { - postToAll(DropDatabasePreEvent(db)) - doDropDatabase(db, ignoreIfNotExists, cascade) - postToAll(DropDatabaseEvent(db)) - } - - protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -87,14 +76,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. 
*/ - final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { - val db = dbDefinition.name - postToAll(AlterDatabasePreEvent(db)) - doAlterDatabase(dbDefinition) - postToAll(AlterDatabaseEvent(db)) - } - - protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit + def alterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -110,41 +92,15 @@ abstract class ExternalCatalog // Tables // -------------------------------------------------------------------------- - final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - val tableDefinitionWithVersion = - tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) - postToAll(CreateTablePreEvent(db, name)) - doCreateTable(tableDefinitionWithVersion, ignoreIfExists) - postToAll(CreateTableEvent(db, name)) - } - - protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - - final def dropTable( - db: String, - table: String, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = { - postToAll(DropTablePreEvent(db, table)) - doDropTable(db, table, ignoreIfNotExists, purge) - postToAll(DropTableEvent(db, table)) - } + def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - protected def doDropTable( + def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit - final def renameTable(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameTablePreEvent(db, oldName, newName)) - doRenameTable(db, oldName, newName) - postToAll(RenameTableEvent(db, oldName, newName)) - } - - protected def doRenameTable(db: String, oldName: String, newName: String): Unit + def renameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -154,15 +110,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) - doAlterTable(tableDefinition) - postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) - } - - protected def doAlterTable(tableDefinition: CatalogTable): Unit + def alterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -173,22 +121,10 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) - doAlterTableDataSchema(db, table, newDataSchema) - postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) - } - - protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. 
*/ - final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) - doAlterTableStats(db, table, stats) - postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) - } - - protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable @@ -340,49 +276,17 @@ abstract class ExternalCatalog // Functions // -------------------------------------------------------------------------- - final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(CreateFunctionPreEvent(db, name)) - doCreateFunction(db, funcDefinition) - postToAll(CreateFunctionEvent(db, name)) - } + def createFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit + def dropFunction(db: String, funcName: String): Unit - final def dropFunction(db: String, funcName: String): Unit = { - postToAll(DropFunctionPreEvent(db, funcName)) - doDropFunction(db, funcName) - postToAll(DropFunctionEvent(db, funcName)) - } + def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doDropFunction(db: String, funcName: String): Unit - - final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(AlterFunctionPreEvent(db, name)) - doAlterFunction(db, funcDefinition) - postToAll(AlterFunctionEvent(db, name)) - } - - protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit - - final def renameFunction(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameFunctionPreEvent(db, oldName, newName)) - doRenameFunction(db, oldName, newName) - postToAll(RenameFunctionEvent(db, oldName, newName)) - } - - protected def doRenameFunction(db: String, oldName: String, newName: String): Unit + def renameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction def functionExists(db: String, funcName: String): Boolean def listFunctions(db: String, pattern: String): Seq[String] - - override protected def doPostEvent( - listener: ExternalCatalogEventListener, - event: ExternalCatalogEvent): Unit = { - listener.onEvent(event) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala new file mode 100644 index 0000000000000..2f009be5816fa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus + +/** + * Wraps an ExternalCatalog to provide listener events. + */ +class ExternalCatalogWithListener(delegate: ExternalCatalog) + extends ExternalCatalog + with ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { + import CatalogTypes.TablePartitionSpec + + def unwrapped: ExternalCatalog = delegate + + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + delegate.createDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + override def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + delegate.dropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } + + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + delegate.alterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + override def getDatabase(db: String): CatalogDatabase = { + delegate.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = { + delegate.databaseExists(db) + } + + override def listDatabases(): Seq[String] = { + delegate.listDatabases() + } + + override def listDatabases(pattern: String): Seq[String] = { + delegate.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = { + delegate.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + val tableDefinitionWithVersion = + tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) + postToAll(CreateTablePreEvent(db, name)) + delegate.createTable(tableDefinitionWithVersion, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + delegate.dropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + delegate.renameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, 
newName)) + } + + override def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + delegate.alterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + override def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + delegate.alterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + override def alterTableStats( + db: String, + table: String, + stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + delegate.alterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + override def getTable(db: String, table: String): CatalogTable = { + delegate.getTable(db, table) + } + + override def tableExists(db: String, table: String): Boolean = { + delegate.tableExists(db, table) + } + + override def listTables(db: String): Seq[String] = { + delegate.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = { + delegate.listTables(db, pattern) + } + + override def loadTable( + db: String, + table: String, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) + } + + override def loadPartition( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadPartition( + db, table, loadPath, partition, isOverwrite, inheritTableSpecs, isSrcLocal) + } + + override def loadDynamicPartitions( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + replace: Boolean, + numDP: Int): Unit = { + delegate.loadDynamicPartitions(db, table, loadPath, partition, replace, numDP) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + delegate.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { + delegate.dropPartitions(db, table, partSpecs, ignoreIfNotExists, purge, retainData) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + delegate.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = { + delegate.alterPartitions(db, table, parts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = { + delegate.getPartition(db, table, spec) + } + + override def getPartitionOption( + db: String, + table: String, + spec: TablePartitionSpec): Option[CatalogTablePartition] = { + delegate.getPartitionOption(db, table, spec) + } + + override def listPartitionNames( + db: String, + table: String, + 
partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + delegate.listPartitionNames(db, table, partialSpec) + } + + override def listPartitions( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + delegate.listPartitions(db, table, partialSpec) + } + + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + delegate.listPartitionsByFilter(db, table, predicates, defaultTimeZoneId) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + delegate.createFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + override def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + delegate.dropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + delegate.alterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + delegate.renameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = { + delegate.getFunction(db, funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = { + delegate.functionExists(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = { + delegate.listFunctions(db, pattern) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8eacfa058bd52..741dc46b07382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( 
tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,7 @@ class InMemoryCatalog( } } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = synchronized { @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { @@ -564,24 +564,24 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { + override def dropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def alterFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c390337c03ff5..afb0f009db05c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils @@ -101,6 +101,8 @@ class SessionCatalog( @GuardedBy("this") protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) + private val validNameFormat = "([\\w_]+)".r + /** * Checks if the given name conforms the Hive standard ("[a-zA-Z_0-9]+"), * i.e. if this name only contains characters, numbers, and _. @@ -109,7 +111,6 @@ class SessionCatalog( * org.apache.hadoop.hive.metastore.MetaStoreUtils.validateName. */ private def validateName(name: String): Unit = { - val validNameFormat = "([\\w_]+)".r if (!validNameFormat.pattern.matcher(name).matches()) { throw new AnalysisException(s"`$name` is not a valid name for tables/databases. " + "Valid names only contain alphabet characters, numbers and _.") @@ -619,6 +620,7 @@ class SessionCatalog( requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) + validateNewLocationOfRename(oldName, newName) externalCatalog.renameTable(db, oldTableName, newTableName) } else { if (newName.database.isDefined) { @@ -683,6 +685,7 @@ class SessionCatalog( * * If the relation is a view, we generate a [[View]] operator from the view description, and * wrap the logical plan in a [[SubqueryAlias]] which will track the name of the view. + * [[SubqueryAlias]] will also keep track of the name and database(optional) of the table/view * * @param name The name of the table/view that we look up. */ @@ -692,7 +695,7 @@ class SessionCatalog( val table = formatTableName(name.table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(table, viewDef) + SubqueryAlias(table, db, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempViews.contains(table)) { val metadata = externalCatalog.getTable(db, table) @@ -705,9 +708,9 @@ class SessionCatalog( desc = metadata, output = metadata.schema.toAttributes, child = parser.parsePlan(viewText)) - SubqueryAlias(table, child) + SubqueryAlias(table, db, child) } else { - SubqueryAlias(table, UnresolvedCatalogRelation(metadata)) + SubqueryAlias(table, db, UnresolvedCatalogRelation(metadata)) } } else { SubqueryAlias(table, tempViews(table)) @@ -1058,7 +1061,7 @@ class SessionCatalog( } /** - * overwirte a metastore function in the database specified in `funcDefinition`.. + * overwrite a metastore function in the database specified in `funcDefinition`.. * If no database is specified, assume the function is in the current database. */ def alterFunction(funcDefinition: CatalogFunction): Unit = { @@ -1123,13 +1126,22 @@ class SessionCatalog( name: String, clazz: Class[_], input: Seq[Expression]): Expression = { + // Unfortunately we need to use reflection here because UserDefinedAggregateFunction + // and ScalaUDAF are defined in sql/core module. 
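// Illustrative sketch, not part of the patch: the reflection path below now also checks
// the call arity of a metastore-registered UDAF up front. The class and jar below are
// hypothetical placeholders, so this is not runnable as-is. With a
// UserDefinedAggregateFunction whose inputSchema declares a single column:
//   spark.sql("CREATE FUNCTION my_sum AS 'com.example.MySumUDAF' USING JAR '/tmp/udaf.jar'")
//   spark.sql("SELECT my_sum(a, b) FROM t")
// the second statement now fails early with
// "Invalid number of arguments for function my_sum. Expected: 1; Found: 2"
// instead of failing later inside ScalaUDAF.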
val clsForUDAF = Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") if (clsForUDAF.isAssignableFrom(clazz)) { val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") - cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) + val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) - .asInstanceOf[Expression] + .asInstanceOf[ImplicitCastInputTypes] + + // Check input argument size + if (e.inputTypes.size != input.size) { + throw new AnalysisException(s"Invalid number of arguments for function $name. " + + s"Expected: ${e.inputTypes.size}; Found: ${input.size}") + } + e } else { throw new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}'. " + s"Use sparkSession.udf.register(...) instead.") @@ -1192,6 +1204,22 @@ class SessionCatalog( !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } + /** + * Return whether this function has been registered in the function registry of the current + * session. If it does not exist, return false. + */ + def isRegisteredFunction(name: FunctionIdentifier): Boolean = { + functionRegistry.functionExists(name) + } + + /** + * Returns whether it is a persistent function. If it does not exist, returns false. + */ + def isPersistentFunction(name: FunctionIdentifier): Boolean = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + databaseExists(db) && externalCatalog.functionExists(db, name.funcName) + } + protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { throw new NoSuchFunctionException( db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) @@ -1366,4 +1394,23 @@ class SessionCatalog( // copy over temporary views tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } + + /** + * Validate the new location before renaming a managed table; the new location must not already exist. + */ + private def validateNewLocationOfRename( + oldName: TableIdentifier, + newName: TableIdentifier): Unit = { + val oldTable = getTableMetadata(oldName) + if (oldTable.tableType == CatalogTableType.MANAGED) { + val databaseLocation = + externalCatalog.getDatabase(oldName.database.getOrElse(currentDb)).locationUri + val newTableLocation = new Path(new Path(databaseLocation), formatTableName(newName.table)) + val fs = newTableLocation.getFileSystem(hadoopConf) + if (fs.exists(newTableLocation)) { + throw new AnalysisException(s"Can not rename the managed table('$oldName')" + + s". 
The associated location('$newTableLocation') already exists.") + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index f3e67dc4e975c..3842d794ba5ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -93,12 +94,16 @@ object CatalogStorageFormat { * @param spec partition spec values indexed by column name * @param storage storage format of the partition * @param parameters some parameters for the partition + * @param createTime creation time of the partition, in milliseconds + * @param lastAccessTime last access time, in milliseconds * @param stats optional statistics (number of rows, total size, etc.) */ case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, storage: CatalogStorageFormat, parameters: Map[String, String] = Map.empty, + createTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, stats: Option[CatalogStatistics] = None) { def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { @@ -109,6 +114,11 @@ case class CatalogTablePartition( if (parameters.nonEmpty) { map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") } + map.put("Created Time", new Date(createTime).toString) + val lastAccess = { + if (-1 == lastAccessTime) "UNKNOWN" else new Date(lastAccessTime).toString + } + map.put("Last Access", lastAccess) stats.foreach(s => map.put("Partition Statistics", s.simpleString)) map } @@ -164,9 +174,12 @@ case class BucketSpec( numBuckets: Int, bucketColumnNames: Seq[String], sortColumnNames: Seq[String]) { - if (numBuckets <= 0 || numBuckets >= 100000) { + def conf: SQLConf = SQLConf.get + + if (numBuckets <= 0 || numBuckets > conf.bucketingMaxBuckets) { throw new AnalysisException( - s"Number of buckets should be greater than 0 but less than 100000. Got `$numBuckets`") + s"Number of buckets should be greater than 0 but less than bucketing.maxBuckets " + + s"(`${conf.bucketingMaxBuckets}`). 
Got `$numBuckets`") } override def toString: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index efb2eba655e15..d3ccd18d0245e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,7 +88,13 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = In(expr, list) + def in(list: Expression*): Expression = list match { + case Seq(l: ListQuery) => expr match { + case c: CreateNamedStruct => InSubquery(c.valExprs, l) + case other => InSubquery(Seq(other), l) + } + case _ => In(expr, list) + } def like(other: Expression): Expression = Like(expr, other) def rlike(other: Expression): Expression = RLike(expr, other) @@ -149,6 +155,7 @@ package object dsl { } } + def rand(e: Long): Expression = Rand(e) def sum(e: Expression): Expression = Sum(e).toAggregateExpression() def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) def count(e: Expression): Expression = Count(e).toAggregateExpression() @@ -165,6 +172,9 @@ package object dsl { def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true) def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) + def coalesce(args: Expression*): Expression = Coalesce(args) + def greatest(args: Expression*): Expression = Greatest(args) + def least(args: Expression*): Expression = Least(args) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) def star(names: String*): Expression = names match { @@ -354,9 +364,11 @@ package object dsl { def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) - def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) + def except(otherPlan: LogicalPlan, isAll: Boolean): LogicalPlan = + Except(logicalPlan, otherPlan, isAll) - def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) + def intersect(otherPlan: LogicalPlan, isAll: Boolean): LogicalPlan = + Intersect(logicalPlan, otherPlan, isAll) def union(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index efc2882f0a3d3..cbea3c017a265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -128,7 +128,7 @@ object ExpressionEncoder { case b: BoundReference if b == originalInputObject => newInputObject }) - if (enc.flat) { + val serializerExpr = if (enc.flat) { newSerializer.head } else { // For non-flat encoder, the input object is not top level anymore after being combined to @@ -146,6 +146,7 @@ object ExpressionEncoder { Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) If(nullCheck, Literal.create(null, struct.dataType), struct) } + Alias(serializerExpr, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f2..3340789398f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4cc84b27d9eb0..77582e10f9ff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -52,17 +53,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = CodeGenerator.javaType(dataType) + val javaType = JavaCode.javaType(dataType) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index d848ba18356d3..fe6db8b344d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -30,6 +30,7 @@ package org.apache.spark.sql.catalyst.expressions * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + * - Elements in [[In]] are reordered by `hashCode`. 
*/ object Canonicalize { def execute(e: Expression): Expression = { @@ -85,6 +86,9 @@ object Canonicalize { case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) + // order the list in the In operator + case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) + case _ => e } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12330bfa55ab9..0053503501047 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -133,6 +134,35 @@ object Cast { toPrecedence > 0 && fromPrecedence > toPrecedence } + /** + * Returns true iff we can safely cast the `from` type to `to` type without any truncation or + * precision loss, e.g. int -> long, date -> timestamp. + */ + def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match { + case _ if from == to => true + case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true + case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true + case (from, to) if legalNumericPrecedence(from, to) => true + case (DateType, TimestampType) => true + case (_, StringType) => true + case _ => false + } + + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) + fromPrecedence >= 0 && fromPrecedence < toPrecedence + } + + def canNullSafeCastToDecimal(from: DataType, to: DecimalType): Boolean = from match { + case from: BooleanType if to.isWiderThan(DecimalType.BooleanDecimal) => true + case from: NumericType if to.isWiderThan(from) => true + case from: DecimalType => + // truncation or precision loss + (to.precision - to.scale) > (from.precision - from.scale) + case _ => false // overflow + } + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false @@ -148,7 +178,7 @@ object Cast { case (DateType, _) => true case (_, CalendarIntervalType) => true - case (_, _: DecimalType) => true // overflow + case (_, to: DecimalType) if !canNullSafeCastToDecimal(from, to) => true case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } @@ -181,7 +211,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType.simpleString} to ${dataType.simpleString}") + s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}") } } @@ -623,21 +653,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = 
nullSafeCastFunction(child.dataType, dataType, ctx) + ev.copy(code = eval.code + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` // in parameter list, because the returned code will be put in null safe evaluation region. - private[this] type CastFunction = (String, String, String) => String + private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block private[this] def nullSafeCastFunction( from: DataType, to: DataType, ctx: CodegenContext): CastFunction = to match { - case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" - case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;" case StringType => castToStringCode(from, ctx) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) @@ -658,18 +689,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) case udt: UserDefinedType[_] if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - (c, evPrim, evNull) => s"$evPrim = $c;" + (c, evPrim, evNull) => code"$evPrim = $c;" case _: UserDefinedType[_] => throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String, - result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { - s""" + private[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue, + result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = { + val javaType = JavaCode.javaType(resultType) + code""" boolean $resultIsNull = $inputIsNull; - ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; + $javaType $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -678,22 +710,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeArrayToStringBuilder( et: DataType, - array: String, - buffer: String, - ctx: CodegenContext): String = { + array: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val elementToStringCode = castToStringCode(et, ctx) val funcName = ctx.freshName("elementToString") - val elementToStringFunc = ctx.addNewFunction(funcName, + val element = JavaCode.variable("element", et) + val elementStr = JavaCode.variable("elementStr", StringType) + val elementToStringFunc = inline"${ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { - | UTF8String elementStr = null; - | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(et)} $element) { + | UTF8String $elementStr = null; + | ${elementToStringCode(element, elementStr, null /* resultIsNull won't be used */)} | return elementStr; |} - """.stripMargin) + """.stripMargin)}" - val loopIndex = ctx.freshName("loopIndex") - s""" + val loopIndex = 
ctx.freshVariable("loopIndex", IntegerType) + code""" |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { @@ -714,31 +748,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeMapToStringBuilder( kt: DataType, vt: DataType, - map: String, - buffer: String, - ctx: CodegenContext): String = { + map: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { def dataToStringFunc(func: String, dataType: DataType) = { val funcName = ctx.freshName(func) val dataToStringCode = castToStringCode(dataType, ctx) - ctx.addNewFunction(funcName, + val data = JavaCode.variable("data", dataType) + val dataStr = JavaCode.variable("dataStr", StringType) + val functionCall = ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { - | UTF8String dataStr = null; - | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) { + | UTF8String $dataStr = null; + | ${dataToStringCode(data, dataStr, null /* resultIsNull won't be used */)} | return dataStr; |} """.stripMargin) + inline"$functionCall" } val keyToStringFunc = dataToStringFunc("keyToString", kt) val valueToStringFunc = dataToStringFunc("valueToString", vt) - val loopIndex = ctx.freshName("loopIndex") - val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") - val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") - val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) - val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) - s""" + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) + val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) + val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) + val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType)) + val getMapFirstValue = CodeGenerator.getValue(mapValueArray, vt, + JavaCode.literal("0", IntegerType)) + val getMapKeyArray = CodeGenerator.getValue(mapKeyArray, kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(mapValueArray, vt, loopIndex) + code""" |$buffer.append("["); |if ($map.numElements() > 0) { | $buffer.append($keyToStringFunc($getMapFirstKey)); @@ -763,20 +803,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeStructToStringBuilder( st: Seq[DataType], - row: String, - buffer: String, - ctx: CodegenContext): String = { + row: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val structToStringCode = st.zipWithIndex.map { case (ft, i) => val fieldToStringCode = castToStringCode(ft, ctx) - val field = ctx.freshName("field") - val fieldStr = ctx.freshName("fieldStr") - s""" - |${if (i != 0) s"""$buffer.append(",");""" else ""} + val field = ctx.freshVariable("field", ft) + val fieldStr = ctx.freshVariable("fieldStr", StringType) + val javaType = JavaCode.javaType(ft) + code""" + |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} |if (!$row.isNullAt($i)) { - | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} | | // Append $i field into the string buffer - | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; + | $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String 
$fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -785,11 +826,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } val writeStructCode = ctx.splitExpressions( - expressions = structToStringCode, + expressions = structToStringCode.map(_.code), funcName = "fieldToString", - arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + arguments = ("InternalRow", row.code) :: + (classOf[UTF8StringBuilder].getName, buffer.code) :: Nil) - s""" + code""" |$buffer.append("["); |$writeStructCode |$buffer.append("]"); @@ -799,20 +841,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromBytes($c);" case DateType => - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" case ArrayType(et, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeArrayElemCode; |$evPrim = $buffer.build(); @@ -820,10 +862,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case MapType(kt, vt, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeMapElemCode; |$evPrim = $buffer.build(); @@ -831,11 +873,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case StructType(fields) => (c, evPrim, evNull) => { - val row = ctx.freshName("row") - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val row = ctx.freshVariable("row", classOf[InternalRow]) + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) - s""" + code""" |InternalRow $row = $c; |$bufferClass $buffer = new $bufferClass(); |$writeStructCode @@ -844,26 +886,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => - val udtRef = 
ctx.addReferenceObj("udt", udt) + val udtRef = JavaCode.global(ctx.addReferenceObj("udt", udt), udt.sqlType) (c, evPrim, evNull) => { - s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" } case _ => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromString(String.valueOf($c));" } } private[this] def castToBinaryCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + (c, evPrim, evNull) => code"$evPrim = $c.getBytes();" } private[this] def castToDateCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val intOpt = ctx.freshName("intOpt") - (c, evPrim, evNull) => s""" + val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]]) + (c, evPrim, evNull) => code""" scala.Option $intOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); if ($intOpt.isDefined()) { @@ -873,16 +915,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);""" case _ => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" } - private[this] def changePrecision(d: String, decimalType: DecimalType, - evPrim: String, evNull: String): String = - s""" + private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, + evPrim: ExprValue, evNull: ExprValue): Block = + code""" if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { $evPrim = $d; } else { @@ -894,11 +937,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { - val tmp = ctx.freshName("tmpDecimal") + val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); ${changePrecision(tmp, target, evPrim, evNull)} @@ -908,37 +951,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case BooleanType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); ${changePrecision(tmp, target, evPrim, evNull)} """ case DateType => // date can't cast to decimal in Hive - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => // Note that we lose precision here. 
(c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply( scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); ${changePrecision(tmp, target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c.clone(); ${changePrecision(tmp, target, evPrim, evNull)} """ case x: IntegralType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply((long) $c); ${changePrecision(tmp, target, evPrim, evNull)} """ case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); ${changePrecision(tmp, target, evPrim, evNull)} @@ -953,10 +996,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - val longOpt = ctx.freshName("longOpt") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]]) (c, evPrim, evNull) => - s""" + code""" scala.Option $longOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $tz); if ($longOpt.isDefined()) { @@ -966,18 +1009,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case _: IntegralType => - (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${longToTimeStampCode(c)};" case DateType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;""" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${decimalToTimestampCode(c)};" case DoubleType => (c, evPrim, evNull) => - s""" + code""" if (Double.isNaN($c) || Double.isInfinite($c)) { $evNull = true; } else { @@ -986,7 +1030,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case FloatType => (c, evPrim, evNull) => - s""" + code""" if (Float.isNaN($c) || Float.isInfinite($c)) { $evNull = true; } else { @@ -998,7 +1042,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"""$evPrim = CalendarInterval.fromString($c.toString()); + code"""$evPrim = CalendarInterval.fromString($c.toString()); if(${evPrim} == null) { ${evNull} = true; } @@ -1006,18 +1050,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } - private[this] def decimalToTimestampCode(d: String): String = - s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" - private[this] def timestampToIntegerCode(ts: String): String = - s"java.lang.Math.floor((double) $ts / 1000000L)" - private[this] def timestampToDoubleCode(ts: 
String): String = s"$ts / 1000000.0" + private[this] def decimalToTimestampCode(d: ExprValue): Block = { + val block = inline"new java.math.BigDecimal(1000000L)" + code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()" + } + private[this] def longToTimeStampCode(l: ExprValue): Block = code"$l * 1000000L" + private[this] def timestampToIntegerCode(ts: ExprValue): Block = + code"java.lang.Math.floor((double) $ts / 1000000L)" + private[this] def timestampToDoubleCode(ts: ExprValue): Block = + code"$ts / 1000000.0" private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => - s""" + code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { @@ -1027,21 +1074,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" case DateType => // Hive would return null when cast from date to boolean - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + (c, evPrim, evNull) => code"$evPrim = !$c.isZero();" case n: NumericType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" } private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; @@ -1051,24 +1098,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (byte) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + (c, evPrim, evNull) => code"$evPrim = $c.toByte();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + (c, evPrim, evNull) => code"$evPrim = (byte) $c;" } private[this] def castToShortCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; @@ -1078,22 +1125,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? 
(short) 1 : (short) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (short) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + (c, evPrim, evNull) => code"$evPrim = $c.toShort();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (short) $c;" + (c, evPrim, evNull) => code"$evPrim = (short) $c;" } private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; @@ -1103,23 +1150,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1 : 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (int) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + (c, evPrim, evNull) => code"$evPrim = $c.toInt();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (int) $c;" + (c, evPrim, evNull) => code"$evPrim = (int) $c;" } private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("longWrapper") + val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; @@ -1129,21 +1176,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (long) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + (c, evPrim, evNull) => code"$evPrim = $c.toLong();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (long) $c;" + (c, evPrim, evNull) => code"$evPrim = (long) $c;" } private[this] def castToFloatCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Float.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1151,21 +1198,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" + (c, evPrim, evNull) => code"$evPrim = $c ? 
1.0f : 0.0f;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (float) $c;" + (c, evPrim, evNull) => code"$evPrim = (float) $c;" } private[this] def castToDoubleCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Double.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1173,31 +1220,32 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (double) $c;" + (c, evPrim, evNull) => code"$evPrim = (double) $c;" } private[this] def castArrayCode( fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) - val arrayClass = classOf[GenericArrayData].getName - val fromElementNull = ctx.freshName("feNull") - val fromElementPrim = ctx.freshName("fePrim") - val toElementNull = ctx.freshName("teNull") - val toElementPrim = ctx.freshName("tePrim") - val size = ctx.freshName("n") - val j = ctx.freshName("j") - val values = ctx.freshName("values") + val arrayClass = JavaCode.javaType(classOf[GenericArrayData]) + val fromElementNull = ctx.freshVariable("feNull", BooleanType) + val fromElementPrim = ctx.freshVariable("fePrim", fromType) + val toElementNull = ctx.freshVariable("teNull", BooleanType) + val toElementPrim = ctx.freshVariable("tePrim", toType) + val size = ctx.freshVariable("n", IntegerType) + val j = ctx.freshVariable("j", IntegerType) + val values = ctx.freshVariable("values", classOf[Array[Object]]) + val javaType = JavaCode.javaType(fromType) (c, evPrim, evNull) => - s""" + code""" final int $size = $c.numElements(); final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { @@ -1205,7 +1253,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${CodeGenerator.javaType(fromType)} $fromElementPrim = + $javaType $fromElementPrim = ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} @@ -1224,23 +1272,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) - val mapClass = classOf[ArrayBasedMapData].getName + val mapClass = JavaCode.javaType(classOf[ArrayBasedMapData]) - val keys = ctx.freshName("keys") - val convertedKeys 
= ctx.freshName("convertedKeys") - val convertedKeysNull = ctx.freshName("convertedKeysNull") + val keys = ctx.freshVariable("keys", ArrayType(from.keyType)) + val convertedKeys = ctx.freshVariable("convertedKeys", ArrayType(to.keyType)) + val convertedKeysNull = ctx.freshVariable("convertedKeysNull", BooleanType) - val values = ctx.freshName("values") - val convertedValues = ctx.freshName("convertedValues") - val convertedValuesNull = ctx.freshName("convertedValuesNull") + val values = ctx.freshVariable("values", ArrayType(from.valueType)) + val convertedValues = ctx.freshVariable("convertedValues", ArrayType(to.valueType)) + val convertedValuesNull = ctx.freshVariable("convertedValuesNull", BooleanType) (c, evPrim, evNull) => - s""" + code""" final ArrayData $keys = $c.keyArray(); final ArrayData $values = $c.valueArray(); - ${castCode(ctx, keys, "false", + ${castCode(ctx, keys, FalseLiteral, convertedKeys, convertedKeysNull, ArrayType(to.keyType), keysCast)} - ${castCode(ctx, values, "false", + ${castCode(ctx, values, FalseLiteral, convertedValues, convertedValuesNull, ArrayType(to.valueType), valuesCast)} $evPrim = new $mapClass($convertedKeys, $convertedValues); @@ -1253,17 +1301,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } - val rowClass = classOf[GenericInternalRow].getName - val tmpResult = ctx.freshName("tmpResult") - val tmpInput = ctx.freshName("tmpInput") + val tmpResult = ctx.freshVariable("tmpResult", classOf[GenericInternalRow]) + val rowClass = JavaCode.javaType(classOf[GenericInternalRow]) + val tmpInput = ctx.freshVariable("tmpInput", classOf[InternalRow]) val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => - val fromFieldPrim = ctx.freshName("ffp") - val fromFieldNull = ctx.freshName("ffn") - val toFieldPrim = ctx.freshName("tfp") - val toFieldNull = ctx.freshName("tfn") - val fromType = CodeGenerator.javaType(from.fields(i).dataType) - s""" + val fromFieldPrim = ctx.freshVariable("ffp", from.fields(i).dataType) + val fromFieldNull = ctx.freshVariable("ffn", BooleanType) + val toFieldPrim = ctx.freshVariable("tfp", to.fields(i).dataType) + val toFieldNull = ctx.freshVariable("tfn", BooleanType) + val fromType = JavaCode.javaType(from.fields(i).dataType) + val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim) + code""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); @@ -1275,18 +1324,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + $setColumn; } } """ } val fieldsEvalCodes = ctx.splitExpressions( - expressions = fieldsEvalCode, + expressions = fieldsEvalCode.map(_.code), funcName = "castStruct", - arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil) + arguments = ("InternalRow", tmpInput.code) :: (rowClass.code, tmpResult.code) :: Nil) (input, result, resultIsNull) => - s""" + code""" final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpInput = $input; $fieldsEvalCodes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala new file mode 100644 index 0000000000000..07fa813a98922 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Defines values for `SQLConf` config of fallback mode. Used for tests only. + */ +object CodegenObjectFactoryMode extends Enumeration { + val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value +} + +/** + * A codegen object generator which creates objects with codegen path first. Once any compile + * error happens, it can fall back to the interpreted implementation. In tests, we can use a SQL config + * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. + */ +abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging { + + def createObject(in: IN): OUT = { + // We are allowed to choose codegen-only or no-codegen modes if under tests. + val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE) + val fallbackMode = CodegenObjectFactoryMode.withName(config) + + fallbackMode match { + case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting => + createCodeGeneratedObject(in) + case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting => + createInterpretedObject(in) + case _ => + try { + createCodeGeneratedObject(in) + } catch { + case NonFatal(_) => + // We should have already seen the error message in `CodeGenerator` + logWarning("Expr codegen error and falling back to interpreter mode") + createInterpretedObject(in) + } + } + } + + protected def createCodeGeneratedObject(in: IN): OUT + protected def createInterpretedObject(in: IN): OUT +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 98f25a9ad7597..981ce0b6a29fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types.AbstractDataType * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define * expected input types without any implicit casting. * - * Most function expressions (e.g. [[Substring]] should extends [[ImplicitCastInputTypes]]) instead. + * Most function expressions (e.g. 
[[Substring]] should extend [[ImplicitCastInputTypes]]) instead. */ trait ExpectsInputTypes extends Expression { @@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression { def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - val mismatches = children.zip(inputTypes).zipWithIndex.collect { - case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + ExpectsInputTypes.checkInputDataTypes(children, inputTypes) + } +} + +object ExpectsInputTypes { + + def checkInputDataTypes( + inputs: Seq[Expression], + inputTypes: Seq[AbstractDataType]): TypeCheckResult = { + val mismatches = inputs.zip(inputTypes).zipWithIndex.collect { + case ((input, expected), idx) if !expected.acceptsType(input.dataType) => s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.sql}' is of ${child.dataType.simpleString} type." + s"however, '${input.sql}' is of ${input.dataType.catalogString} type." } if (mismatches.isEmpty) { @@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression { } } - /** * A mixin for the analyzer to perform implicit type casting using * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6ae88299..773aefc0ac1f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] { JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) - if (eval.code.nonEmpty) { + if (eval.code.toString.nonEmpty) { // Add `this` in the comment. 
- eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) + eval.copy(code = ctx.registerComment(this.toString) + eval.code) } else { eval } @@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] { val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | $setIsNull | return ${eval.value}; |} """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression { if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -581,10 +580,10 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { // First check whether left and right have the same type, then check if the type is acceptable. 
if (!left.dataType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + - s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + s"(${left.dataType.catalogString} and ${right.dataType.catalogString}).") } else if (!inputType.acceptsType(left.dataType)) { TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," + - s" not ${left.dataType.simpleString}") + s" not ${left.dataType.catalogString}") } else { TypeCheckResult.TypeCheckSuccess } @@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} @@ -697,6 +695,36 @@ abstract class TernaryExpression extends Expression { } } +/** + * A trait resolving nullable, containsNull, valueContainsNull flags of the output data type. + * This logic is usually utilized by expressions combining data from multiple child expressions + * of non-primitive types (e.g. [[CaseWhen]]). + */ +trait ComplexTypeMergingExpression extends Expression { + + /** + * A collection of data types used for resolving the output type of the expression. By default, + * data types of all child expressions. The collection must not be empty. + */ + @transient + lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) + + def dataTypeCheck: Unit = { + require( + inputTypesForMerging.nonEmpty, + "The collection of input data types must not be empty.") + require( + TypeCoercion.haveSameType(inputTypesForMerging), + "All input types must be the same except nullable, containsNull, valueContainsNull flags." + + s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") + } + + override def dataType: DataType = { + dataTypeCheck + inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) + } +} + /** * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages * and Hive function wrappers. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 6d69d69b1c802..55a5bd380859e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** * Helper functions for creating an [[InterpretedUnsafeProjection]]. */ -object InterpretedUnsafeProjection extends UnsafeProjectionCreator { - +object InterpretedUnsafeProjection { /** * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. */ - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + def createProjection(exprs: Seq[Expression]): UnsafeProjection = { // We need to make sure that we do not reuse stateful expressions. 
val cleanedExpressions = exprs.map(_.transform { case s: Stateful => s.freshCopy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index ad1e7bdb31987..f1da592a76845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} /** @@ -38,6 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType} puts the partition ID in the upper 31 bits, and the lower 33 bits represent the record number within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + The function is non-deterministic because its result depends on partition IDs. """) case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { @@ -71,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 3cd73682188bc..226a4ddcffaa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.types.{DataType, StructType} @@ -108,7 +110,32 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -trait UnsafeProjectionCreator { +/** + * The factory object for `UnsafeProjection`. 
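// A worked example of the 64-bit layout that the monotonically_increasing_id() documentation
// in MonotonicallyIncreasingID.scala above describes: partition index in the upper 31 bits,
// per-partition record number in the lower 33 bits. The concrete numbers are illustrative.
object MonotonicIdSketch {
  val partitionIndex = 3L
  val recordNumber = 7L
  // Mirrors the generated code: partitionMask + count
  val id = (partitionIndex << 33) + recordNumber // 25769803783L
}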
+ */ +object UnsafeProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(in) + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + + protected def toBoundExprs( + exprs: Seq[Expression], + inputSchema: Seq[Attribute]): Seq[Expression] = { + exprs.map(BindReferences.bindReference(_, inputSchema)) + } + + protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { + exprs.map(_ transform { + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + } + /** * Returns an UnsafeProjection for given StructType. * @@ -129,10 +156,7 @@ trait UnsafeProjectionCreator { * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { - val unsafeExprs = exprs.map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - createProjection(unsafeExprs) + createObject(toUnsafeExprs(exprs)) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -142,34 +166,27 @@ trait UnsafeProjectionCreator { * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - create(exprs.map(BindReferences.bindReference(_, inputSchema))) - } - - /** - * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. - */ - protected def createProjection(exprs: Seq[Expression]): UnsafeProjection -} - -object UnsafeProjection extends UnsafeProjectionCreator { - - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + create(toBoundExprs(exprs, inputSchema)) } /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * TODO: refactor the plumbing and clean this up. + * The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example, + * when fallbacking to interpreted execution, it is not supported. 
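// The UnsafeProjection factory above plugs into CodeGeneratorWithInterpretedFallback. A
// self-contained sketch of the same "try codegen, fall back to interpretation" shape; the
// names FallbackFactory, compile and interpret are made up for illustration, not the real API.
import scala.util.control.NonFatal

abstract class FallbackFactory[IN, OUT] {
  protected def compile(in: IN): OUT   // may throw if code generation fails
  protected def interpret(in: IN): OUT // always available

  def createObject(in: IN): OUT =
    try compile(in) catch { case NonFatal(_) => interpret(in) }
}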
*/ def create( exprs: Seq[Expression], inputSchema: Seq[Attribute], subexpressionEliminationEnabled: Boolean): UnsafeProjection = { - val e = exprs.map(BindReferences.bindReference(_, inputSchema)) - .map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema)) + try { + GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) + } catch { + case NonFatal(_) => + // We should have already seen the error message in `CodeGenerator` + logWarning("Expr codegen error and falling back to interpreter mode") + InterpretedUnsafeProjection.createProjection(unsafeExprs) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index efd664dde725a..6530b176968f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -34,10 +34,14 @@ object PythonUDF { e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) } - def isGroupAggPandasUDF(e: Expression): Boolean = { + def isGroupedAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } + + // This is currently same as GroupedAggPandasUDF, but we might support new types in the future, + // e.g, N -> N transform. + def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index e869258469a97..8954fe8a58e6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType /** @@ -38,6 +39,7 @@ import org.apache.spark.sql.types.DataType * @param nullable True if the UDF can return null value. * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result * each time it is invoked with a particular input. + * @param nullableTypes which of the inputTypes are nullable (i.e. 
not primitive) */ case class ScalaUDF( function: AnyRef, @@ -46,7 +48,8 @@ case class ScalaUDF( inputTypes: Seq[DataType] = Nil, udfName: Option[String] = None, nullable: Boolean = true, - udfDeterministic: Boolean = true) + udfDeterministic: Boolean = true, + nullableTypes: Seq[Boolean] = Nil) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { // The constructor for SPARK 2.1 and 2.2 @@ -57,7 +60,8 @@ case class ScalaUDF( inputTypes: Seq[DataType], udfName: Option[String]) = { this( - function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true) + function, dataType, children, inputTypes, udfName, nullable = true, + udfDeterministic = true, nullableTypes = Nil) } override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) @@ -1030,7 +1034,7 @@ case class ScalaUDF( """.stripMargin ev.copy(code = - s""" + code""" |$evalCode |${initArgs.mkString("\n")} |$callFunc @@ -1047,8 +1051,9 @@ case class ScalaUDF( lazy val udfErrorMessage = { val funcCls = function.getClass.getSimpleName - val inputTypes = children.map(_.dataType.simpleString).mkString(", ") - s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})" + val inputTypes = children.map(_.dataType.catalogString).mkString(", ") + val outputType = dataType.catalogString + s"Failed to execute user defined function($funcCls: ($inputTypes) => $outputType)" } override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f714905..536276b5cb29f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ abstract sealed class SortDirection { @@ -71,7 +73,7 @@ case class SortOrder( if (RowOrdering.isOrderable(dataType)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}") + TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.catalogString}") } } @@ -147,7 +149,41 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { (!child.isAscending && child.nullOrdering == NullsLast) } - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + private lazy val calcPrefix: Any => Long = child.child.dataType match { + case BooleanType => (raw) => + if (raw.asInstanceOf[Boolean]) 1 else 0 + case DateType | TimestampType | _: IntegralType => (raw) => + raw.asInstanceOf[java.lang.Number].longValue() + case FloatType | DoubleType => (raw) => { + val dVal = raw.asInstanceOf[java.lang.Number].doubleValue() + DoublePrefixComparator.computePrefix(dVal) + } + case StringType => (raw) => + StringPrefixComparator.computePrefix(raw.asInstanceOf[UTF8String]) + case BinaryType => (raw) => + 
BinaryPrefixComparator.computePrefix(raw.asInstanceOf[Array[Byte]]) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + _.asInstanceOf[Decimal].toUnscaledLong + case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => + val p = Decimal.MAX_LONG_DIGITS + val s = p - (dt.precision - dt.scale) + (raw) => { + val value = raw.asInstanceOf[Decimal] + if (value.changePrecision(p, s)) value.toUnscaledLong else Long.MinValue + } + case dt: DecimalType => (raw) => + DoublePrefixComparator.computePrefix(raw.asInstanceOf[Decimal].toDouble) + case _ => (Any) => 0L + } + + override def eval(input: InternalRow): Any = { + val value = child.child.eval(input) + if (value == null) { + null + } else { + calcPrefix(value) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childCode = child.child.genCode(ctx) @@ -181,7 +217,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } ev.copy(code = childCode.code + - s""" + code""" |long ${ev.value} = 0L; |boolean ${ev.isNull} = ${childCode.isNull}; |if (!${childCode.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 787bcaf5e81de..9856b37e53fbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { val idTerm = "partitionId" ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 6c4a3601c1730..8e48856d4607c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -79,16 +80,13 @@ case class TimeWindow( if (slideDuration <= 0) { return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.") } - if (startTime < 0) { - return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.") - } if (slideDuration > 
windowDuration) { return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" + s" to the windowDuration ($windowDuration).") } - if (startTime >= slideDuration) { - return TypeCheckFailure(s"The start time ($startTime) must be less than the " + - s"slideDuration ($slideDuration).") + if (startTime.abs >= slideDuration) { + return TypeCheckFailure(s"The absolute value of start time ($startTime) must be less " + + s"than the slideDuration ($slideDuration).") } } dataTypeCheck @@ -164,7 +162,7 @@ case class PreciseTimestampConversion( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + - s"""boolean ${ev.isNull} = ${eval.isNull}; + code"""boolean ${ev.isNull} = ${eval.isNull}; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index d4421ca20a9bd..f96a087972f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -63,11 +63,11 @@ case class ApproxCountDistinctForIntervals( } // Mark as lazy so that endpointsExpression is not evaluated during tree transformation. - lazy val endpoints: Array[Double] = - (endpointsExpression.dataType, endpointsExpression.eval()) match { - case (ArrayType(elementType, _), arrayData: ArrayData) => - arrayData.toObjectArray(elementType).map(_.toString.toDouble) - } + lazy val endpoints: Array[Double] = { + val endpointsType = endpointsExpression.dataType.asInstanceOf[ArrayType] + val endpoints = endpointsExpression.eval().asInstanceOf[ArrayData] + endpoints.toObjectArray(endpointsType.elementType).map(_.toString.toDouble) + } override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index a45854a3b5146..c790d87492c73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -132,7 +132,7 @@ case class ApproximatePercentile( case TimestampType => value.asInstanceOf[Long].toDouble case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType]) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") + throw new UnsupportedOperationException(s"Unexpected data type ${other.catalogString}") } buffer.add(doubleValue) } @@ -157,7 +157,7 @@ case class ApproximatePercentile( case DoubleType => doubleResult case _: DecimalType => doubleResult.map(Decimal(_)) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") + throw new UnsupportedOperationException(s"Unexpected data type ${other.catalogString}") } if (result.length == 0) { null @@ -206,27 +206,15 @@ object ApproximatePercentile { * with limited memory. 
PercentileDigest is backed by [[QuantileSummaries]]. * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. - * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the - * underlying quantileSummaries is compressed. */ - class PercentileDigest( - private var summaries: QuantileSummaries, - private var isCompressed: Boolean) { - - // Trigger compression if the QuantileSummaries's buffer length exceeds - // compressThresHoldBufferLength. The buffer length can be get by - // quantileSummaries.sampled.length - private[this] final val compressThresHoldBufferLength: Int = { - // Max buffer length after compression. - val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 - // A safe upper bound for buffer length before compression - maxBufferLengthAfterCompression * 2 - } + class PercentileDigest(private var summaries: QuantileSummaries) { def this(relativeError: Double) = { - this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) } + private[sql] def isCompressed: Boolean = summaries.compressed + /** Returns compressed object of [[QuantileSummaries]] */ def quantileSummaries: QuantileSummaries = { if (!isCompressed) compress() @@ -236,14 +224,6 @@ object ApproximatePercentile { /** Insert an observation value into the PercentileDigest data structure. */ def add(value: Double): Unit = { summaries = summaries.insert(value) - // The result of QuantileSummaries.insert is un-compressed - isCompressed = false - - // Currently, QuantileSummaries ignores the construction parameter compressThresHold, - // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here - // to make sure QuantileSummaries doesn't occupy infinite memory. - // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold - if (summaries.sampled.length >= compressThresHoldBufferLength) compress() } /** In-place merges in another PercentileDigest. 
*/ @@ -280,7 +260,6 @@ object ApproximatePercentile { private final def compress(): Unit = { summaries = summaries.compress() - isCompressed = true } } @@ -335,8 +314,8 @@ object ApproximatePercentile { sampled(i) = Stats(value, g, delta) i += 1 } - val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) - new PercentileDigest(summary, isCompressed = true) + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true) + new PercentileDigest(summary) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 708bdbfc36058..5ecb77be5965e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,30 +17,18 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - - override def prettyName: String = "avg" - - override def children: Seq[Expression] = child :: Nil +abstract class AverageLike(child: Expression) extends DeclarativeAggregate { override def nullable: Boolean = true - // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") - private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) @@ -58,18 +46,10 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit override lazy val aggBufferAttributes = sum :: count :: Nil override lazy val initialValues = Seq( - /* sum = */ Cast(Literal(0), sumDataType), + /* sum = */ Literal(0).cast(sumDataType), /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* sum = */ - Add( - sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* count = */ If(IsNull(child), count, count + 1L) - ) - override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right @@ -77,12 +57,34 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit // If all input are nulls, count will be 0 and we will get null after the division. 
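// To make the comment above concrete, a plain-Scala restatement of the (sum, count) buffer
// semantics used by AverageLike's update expressions further down: NULL inputs contribute
// nothing to sum and do not bump count, and an all-null (or empty) group yields NULL.
object AverageSketch {
  val rows = Seq(Option(10.0), None, Option(20.0))
  val (sum, count) = rows.foldLeft((0.0, 0L)) {
    case ((s, c), v) => (s + v.getOrElse(0.0), c + (if (v.isDefined) 1 else 0))
  }
  val avg = if (count == 0) None else Some(sum / count) // Some(15.0)
}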
override lazy val evaluateExpression = child.dataType match { - case DecimalType.Fixed(p, s) => - // increase the precision and scale to prevent precision loss - val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), - resultType) + case _: DecimalType => + DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) case _ => - Cast(sum, resultType) / Cast(count, resultType) + sum.cast(resultType) / count.cast(resultType) } + + protected def updateExpressionsDef: Seq[Expression] = Seq( + /* sum = */ + Add( + sum, + coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))), + /* count = */ If(child.isNull, count, count + 1L) + ) + + override lazy val updateExpressions = updateExpressionsDef +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") +case class Average(child: Expression) + extends AverageLike(child) with ImplicitCastInputTypes { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 572d29caf5bc9..e2ff0efba07ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression) override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val delta = child - avg - val deltaN = delta / newN - val newAvg = avg + deltaN - val newM2 = m2 + delta * (delta - deltaN) - - val delta2 = delta * delta - val deltaN2 = deltaN * deltaN - val newM3 = if (momentOrder >= 3) { - m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) - } else { - Literal(0.0) - } - val newM4 = if (momentOrder >= 4) { - m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + - delta * (delta * delta2 - deltaN * deltaN2) - } else { - Literal(0.0) - } - - trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) - )) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -103,7 +75,7 @@ abstract class CentralMomentAgg(child: Expression) val n2 = n.right val newN = n1 + n2 val delta = avg.right - avg.left - val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) + val deltaN = If(newN === 0.0, 0.0, delta / newN) val newAvg = avg.left + deltaN * n2 // higher order moments computed according to: @@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression) trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + 1.0 + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - 
deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(child.isNull, n, newN), + If(child.isNull, avg, newAvg), + If(child.isNull, m2, newM2), + If(child.isNull, m3, newM3), + If(child.isNull, m4, newM4) + )) + } } // Compute the population standard deviation of a column @@ -140,8 +142,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - Sqrt(m2 / n)) + If(n === 0.0, Literal.create(null, DoubleType), sqrt(m2 / n)) } override def prettyName: String = "stddev_pop" @@ -157,9 +158,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - Sqrt(m2 / (n - Literal(1.0))))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0)))) } override def prettyName: String = "stddev_samp" @@ -173,8 +173,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - m2 / n) + If(n === 0.0, Literal.create(null, DoubleType), m2 / n) } override def prettyName: String = "var_pop" @@ -188,9 +187,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - m2 / (n - Literal(1.0)))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, m2 / (n - 1.0))) } override def prettyName: String = "var_samp" @@ -205,9 +203,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 3 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(m2 === Literal(0.0), Literal(Double.NaN), - Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) + If(n === 0.0, Literal.create(null, DoubleType), + If(m2 === 0.0, Double.NaN, sqrt(n) * m3 / sqrt(m2 * m2 * m2))) } } @@ -218,9 +215,8 @@ case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 4 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(m2 === Literal(0.0), Literal(Double.NaN), - n * m4 / (m2 * m2) - Literal(3.0))) + If(n === 0.0, Literal.create(null, DoubleType), + If(m2 === 0.0, Double.NaN, n * m4 / (m2 * m2) - 3.0)) } override def prettyName: String = "kurtosis" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 95a4a0d5af634..e14cc716ea223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ /** - * Compute Pearson correlation between two expressions. + * Base class for computing Pearson correlation between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. * * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") -// scalastyle:on line.size.limit -case class Corr(x: Expression, y: Expression) +abstract class PearsonCorrelation(x: Expression, y: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) @@ -51,8 +47,27 @@ case class Corr(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef + + override val mergeExpressions: Seq[Expression] = { + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === 0.0, 0.0, dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === 0.0, 0.0, dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) + } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + 1.0 val dx = x - xAvg val dxN = dx / newN val dy = y - yAvg @@ -63,7 +78,7 @@ case class Corr(x: Expression, y: Expression) val newXMk = xMk + dx * (x - newXAvg) val newYMk = yMk + dy * (y - newYAvg) - val isNull = IsNull(x) || IsNull(y) + val isNull = x.isNull || y.isNull Seq( If(isNull, n, newN), If(isNull, xAvg, newXAvg), @@ -73,29 +88,19 @@ case class Corr(x: Expression, y: Expression) If(isNull, yMk, newYMk) ) } +} - override val mergeExpressions: Seq[Expression] = { - - val n1 = n.left - val n2 = n.right - val newN = n1 + n2 - val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) - val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) - val newXAvg = xAvg.left + dxN * n2 - val newYAvg = yAvg.left + dyN * n2 - val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 - val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 - Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) - } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") +// scalastyle:on line.size.limit +case class Corr(x: Expression, y: Expression) + extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - ck / Sqrt(xMk * yMk))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, ck / sqrt(xMk * yMk))) } override def prettyName: String = "corr" 
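// A plain-Scala transcription of the one-pass update that PearsonCorrelation's
// updateExpressionsDef above encodes (a Welford-style running co-moment). It is a readability
// sketch over doubles, not the Catalyst expression implementation.
final case class CorrBuffer(n: Double, xAvg: Double, yAvg: Double, ck: Double, xMk: Double, yMk: Double)

object CorrSketch {
  def update(b: CorrBuffer, x: Double, y: Double): CorrBuffer = {
    val newN = b.n + 1.0
    val dx = x - b.xAvg
    val dy = y - b.yAvg
    val newXAvg = b.xAvg + dx / newN
    val newYAvg = b.yAvg + dy / newN
    CorrBuffer(newN, newXAvg, newYAvg,
      b.ck + dx * (y - newYAvg),
      b.xMk + dx * (x - newXAvg),
      b.yMk + dy * (y - newYAvg))
  }
  // corr evaluates to ck / sqrt(xMk * yMk): NULL when n == 0, NaN when n == 1.
}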
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 1990f2f2f0722..40582d0abd762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. - - _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. - - _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. - """) -// scalastyle:on line.size.limit -case class Count(children: Seq[Expression]) extends DeclarativeAggregate { - +/** + * Base class for all counting aggregators. + */ +abstract class CountLike extends DeclarativeAggregate { override def nullable: Boolean = false // Return data type. override def dataType: DataType = LongType - private lazy val count = AttributeReference("count", LongType, nullable = false)() + protected lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) + override lazy val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override lazy val evaluateExpression = count + + override def defaultResult: Option[Literal] = Option(Literal(0L)) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. + + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. + + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. 
+ """) +// scalastyle:on line.size.limit +case class Count(children: Seq[Expression]) extends CountLike { + override lazy val updateExpressions = { val nullableChildren = children.filter(_.nullable) if (nullableChildren.isEmpty) { @@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { ) } } - - override lazy val mergeExpressions = Seq( - /* count = */ count.left + count.right - ) - - override lazy val evaluateExpression = count - - override def defaultResult: Option[Literal] = Option(Literal(0L)) } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index fc6c34baafdd1..ee28eb591882f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) - override lazy val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val dx = x - xAvg - val dy = y - yAvg - val dyN = dy / newN - val newXAvg = xAvg + dx / newN - val newYAvg = yAvg + dyN - val newCk = ck + dx * (y - newYAvg) - - val isNull = IsNull(x) || IsNull(y) - Seq( - If(isNull, n, newN), - If(isNull, xAvg, newXAvg), - If(isNull, yAvg, newYAvg), - If(isNull, ck, newCk) - ) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -66,23 +50,40 @@ abstract class Covariance(x: Expression, y: Expression) val n2 = n.right val newN = n1 + n2 val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dxN = If(newN === 0.0, 0.0, dx / newN) val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val dyN = If(newN === 0.0, 0.0, dy / newN) val newXAvg = xAvg.left + dxN * n2 val newYAvg = yAvg.left + dyN * n2 val newCk = ck.left + ck.right + dx * dyN * n1 * n2 Seq(newN, newXAvg, newYAvg, newCk) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + 1.0 + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = x.isNull || y.isNull + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) + } } @ExpressionDescription( usage = "_FUNC_(expr1, expr2) - Returns the population covariance of a set of number pairs.") case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - ck / n) + If(n === 0.0, Literal.create(null, DoubleType), ck / n) } override def prettyName: String = "covar_pop" } @@ -92,9 +93,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance usage = "_FUNC_(expr1, expr2) - Returns the sample covariance of a set of number pairs.") case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - ck / (n - Literal(1.0)))) 
+ If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, ck / (n - 1.0))) } override def prettyName: String = "covar_samp" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 4e671e1f3e6eb..f51bfd591204a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* first = */ If(Or(valueSet, IsNull(child)), first, child), - /* valueSet = */ Or(valueSet, IsNotNull(child)) + /* first = */ If(valueSet || child.isNull, first, child), + /* valueSet = */ valueSet || child.isNotNull ) } else { Seq( @@ -97,7 +98,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) // false, we are safe to do so because first.right will be null in this case). Seq( /* first = */ If(valueSet.left, first.left, first.right), - /* valueSet = */ Or(valueSet.left, valueSet.right) + /* valueSet = */ valueSet.left || valueSet.right ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 0ccabb9d98914..2650d7b5908fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* last = */ If(IsNull(child), last, child), - /* valueSet = */ Or(valueSet, IsNotNull(child)) + /* last = */ If(child.isNull, last, child), + /* valueSet = */ valueSet || child.isNotNull ) } else { Seq( @@ -95,7 +96,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) // Prefer the right hand expression if it has been set. 
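// A plain-Scala analogue of the Last merge described by the comment above and spelled out in
// the expressions that follow: the right-hand (later) partial result wins whenever it has seen
// a value; otherwise the left buffer is kept. LastBuffer is a hypothetical name for this sketch.
final case class LastBuffer[T](last: Option[T], valueSet: Boolean)

object LastMergeSketch {
  def merge[T](left: LastBuffer[T], right: LastBuffer[T]): LastBuffer[T] =
    if (right.valueSet) right
    else LastBuffer(left.last, left.valueSet || right.valueSet)
}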
Seq( /* last = */ If(valueSet.right, last.right, last.left), - /* valueSet = */ Or(valueSet.right, valueSet.left) + /* valueSet = */ valueSet.right || valueSet.left ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 58fd1d8620e16..71099eba0fc75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -45,12 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* max = */ Greatest(Seq(max, child)) + /* max = */ greatest(max, child) ) override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* max = */ Greatest(Seq(max.left, max.right)) + /* max = */ greatest(max.left, max.right) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index b2724ee76827c..8c4ba93231cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -45,12 +46,12 @@ case class Min(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* min = */ Least(Seq(min, child)) + /* min = */ least(min, child) ) override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* min = */ Least(Seq(min.left, min.right)) + /* min = */ least(min.left, min.right) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 523714869242d..33bc5b5821b36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import scala.collection.immutable.HashMap +import scala.collection.immutable.{HashMap, TreeMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ object PivotFirst { @@ -83,7 +83,12 @@ case class PivotFirst( override val dataType: DataType = ArrayType(valueDataType) - val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*) + val pivotIndex = 
if (pivotColumn.dataType.isInstanceOf[AtomicType]) { + HashMap(pivotColumnValues.zipWithIndex: _*) + } else { + TreeMap(pivotColumnValues.zipWithIndex: _*)( + TypeUtils.getInterpretedOrdering(pivotColumn.dataType)) + } val indexSize = pivotIndex.size diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 86e40a9713b36..761dba111c074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -61,12 +62,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast if (child.nullable) { Seq( /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) ) } else { Seq( /* sum = */ - Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + coalesce(sum, zero) + child.cast(sumDataType) ) } } @@ -74,7 +75,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) + coalesce(coalesce(sum.left, zero) + sum.right, sum.left) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala new file mode 100644 index 0000000000000..d8f4505588ff2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{AbstractDataType, DoubleType} + +/** + * Base trait for all regression functions. 
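// A plain-Scala restatement of the nullable-child Sum update in Sum.scala above:
// sum = coalesce(coalesce(sum, 0) + child, sum), so a NULL input leaves the buffer unchanged
// and the buffer itself stays NULL until the first non-null value arrives.
object SumSketch {
  def updateSum(sum: Option[Long], child: Option[Long]): Option[Long] =
    child.map(c => sum.getOrElse(0L) + c).orElse(sum)

  val a = updateSum(None, None)     // None: no non-null input yet, SUM stays NULL
  val b = updateSum(None, Some(5L)) // Some(5)
  val c = updateSum(Some(5L), None) // Some(5)
}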
+ */ +trait RegrLike extends AggregateFunction with ImplicitCastInputTypes { + def y: Expression + def x: Expression + + override def children: Seq[Expression] = Seq(y, x) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = { + assert(aggBufferAttributes.length == exprs.length) + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + exprs + } else { + exprs.zip(aggBufferAttributes).map { case (e, a) => + If(nullableChildren.map(IsNull).reduce(Or), a, e) + } + } + } +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the number of non-null pairs.", + since = "2.4.0") +case class RegrCount(y: Expression, x: Expression) + extends CountLike with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L)) + + override def prettyName: String = "regr_count" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSXX(y: Expression, x: Expression) + extends CentralMomentAgg(x) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_sxx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSYY(y: Expression, x: Expression) + extends CentralMomentAgg(y) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_syy" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgX(y: Expression, x: Expression) + extends AverageLike(x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgY(y: Expression, x: Expression) + extends AverageLike(y) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgy" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. 
Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSXY(y: Expression, x: Expression) + extends Covariance(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), ck) + } + + override def prettyName: String = "regr_sxy" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSlope(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk) + } + + override def prettyName: String = "regr_slope" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrR2(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk)) + } + + override def prettyName: String = "regr_r2" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. 
Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrIntercept(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + xAvg - (ck / yMk) * yAvg) + } + + override def prettyName: String = "regr_intercept" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4e322d23b95b..c827226d58420 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -220,30 +221,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", - examples = """ - Examples: - > SELECT 3 _FUNC_ 2; - 1.5 - > SELECT 2L _FUNC_ 2L; - 1.0 - """) -// scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) +// Common base trait for Divide and Remainder, since these two classes are almost identical +trait DivModLike extends BinaryArithmetic { - override def symbol: String = "/" - override def decimalMethod: String = "$div" override def nullable: Boolean = true - private lazy val div: (Any, Any) => Any = dataType match { - case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - } - - override def eval(input: InternalRow): Any = { + final override def eval(input: InternalRow): Any = { val input2 = right.eval(input) if (input2 == null || input2 == 0) { null @@ -252,13 +235,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (input1 == null) { null } else { - div(input1, input2) + evalOperation(input1, input2) } } } + def evalOperation(left: Any, right: Any): Any + /** - * Special case handling due to division by 0 => null. + * Special case handling due to division/remainder by 0 => null. 
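// A plain-Scala sketch of the shared eval contract that the DivModLike trait above factors
// out of Divide and Remainder: a NULL or zero divisor yields NULL, a NULL dividend yields
// NULL, and otherwise the concrete operation is applied.
object DivModSketch {
  def divModEval[T](left: Option[T], right: Option[T], isZero: T => Boolean)(op: (T, T) => T): Option[T] =
    right match {
      case None                 => None
      case Some(r) if isZero(r) => None
      case Some(r)              => left.map(op(_, r))
    }

  val q = divModEval(Some(7), Some(2), (_: Int) == 0)(_ / _) // Some(3)
  val z = divModEval(Some(7), Some(0), (_: Int) == 0)(_ / _) // None: x / 0 => NULL
  val m = divModEval(Some(7), Some(2), (_: Int) == 0)(_ % _) // Some(1)
}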
*/ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) @@ -269,13 +254,13 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val divide = if (dataType.isInstanceOf[DecimalType]) { + val operation = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -283,10 +268,10 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ${ev.isNull} = true; } else { ${eval1.code} - ${ev.value} = $divide; + ${ev.value} = $operation; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -297,13 +282,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.value} = $divide; + ${ev.value} = $operation; } }""") } } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1.5 + > SELECT 2L _FUNC_ 2L; + 1.0 + """) +// scalastyle:on line.size.limit +case class Divide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) + + override def symbol: String = "/" + override def decimalMethod: String = "$div" + + private lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ @@ -313,82 +323,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic > SELECT MOD(2, 1.8); 0.2 """) -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = NumericType override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true - private lazy val integral = dataType match { - case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] - case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + private lazy val mod: (Any, Any) => Any = dataType match { + // special cases to make float/double primitive types faster + case DoubleType => + (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double] + case FloatType => + (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float] + + // catch-all cases + case i: IntegralType => + val integral = i.integral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) + case i: FractionalType => // should only be DecimalType for now + val integral = i.asIntegral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) } - override def eval(input: InternalRow): Any = { - val input2 = 
right.eval(input) - if (input2 == null || input2 == 0) { - null - } else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - input1 match { - case d: Double => d % input2.asInstanceOf[java.lang.Double] - case f: Float => f % input2.asInstanceOf[java.lang.Float] - case _ => integral.rem(input1, input2) - } - } - } - } - - /** - * Special case handling for x % 0 ==> null. - */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) - val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" - } else { - s"${eval2.value} == 0" - } - val javaType = CodeGenerator.javaType(dataType) - val remainder = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.value}.$decimalMethod(${eval2.value})" - } else { - s"($javaType)(${eval1.value} $symbol ${eval2.value})" - } - if (!left.nullable && !right.nullable) { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if ($isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - ${ev.value} = $remainder; - }""") - } else { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { - ${ev.isNull} = true; - } else { - ${ev.value} = $remainder; - } - }""") - } - } + override def evalOperation(left: Any, right: Any): Any = mod(left, right) } @ExpressionDescription( @@ -479,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -490,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { $result }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -556,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { > SELECT _FUNC_(10, 9, 2, 4, 3); 2 """) -case class Least(children: Seq[Expression]) extends Expression { +case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -567,17 +525,15 @@ case class Least(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") + s" got LEAST(${children.map(_.dataType.catalogString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) @@ -612,7 +568,7 @@ case class Least(children: Seq[Expression]) extends Expression { 
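  // Reading aid for the Least/Greatest change above: mixing in ComplexTypeMergingExpression
  // removes the hand-written `dataType = children.head.dataType` and instead derives the result
  // type by merging the types of all children (reconciling nullable/containsNull flags of nested
  // struct/array types), while the type-check failure now prints the full `catalogString` of each
  // input rather than the potentially truncated `simpleString`.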
""".stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes @@ -631,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression { > SELECT _FUNC_(10, 9, 2, 4, 3); 10 """) -case class Greatest(children: Seq[Expression]) extends Expression { +case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -642,17 +598,15 @@ case class Greatest(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") + s" got GREATEST(${children.map(_.dataType.catalogString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) @@ -687,7 +641,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 7b398f424cead..ea1bb87d415c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -27,6 +27,10 @@ import java.util.regex.Matcher */ object CodeFormatter { val commentHolder = """\/\*(.+?)\*\/""".r + val commentRegexp = + ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ + """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment + val extraNewLinesRegexp = """\n\s*\n""".r // strip extra newlines def format(code: CodeAndComment, maxLines: Int = -1): String = { val formatter = new CodeFormatter @@ -91,11 +95,7 @@ object CodeFormatter { } def stripExtraNewLinesAndComments(input: String): String = { - val commentReg = - ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ - """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment - val codeWithoutComment = commentReg.replaceAllIn(input, "") - codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines + extraNewLinesRegexp.replaceAllIn(commentRegexp.replaceAllIn(input, ""), "\n") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cf0a91ff00626..b8f09761f61ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,10 +38,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types._ import org.apache.spark.util.{ParentClassLoader, Utils} @@ -56,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = "", isNull, value) + ExprCode(code = EmptyBlock, isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = "", isNull = FalseLiteral, value = value) + ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value) } } @@ -329,9 +331,9 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => code"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" + case _ => code"$value = $initCode;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } @@ -469,6 +471,8 @@ class CodegenContext { case NewFunctionSpec(functionName, None, None) => functionName case NewFunctionSpec(functionName, Some(_), Some(innerClassInstance)) => innerClassInstance + "." + functionName + case _ => + throw new IllegalArgumentException(s"$funcName is not matched at addNewFunction") } } @@ -577,6 +581,18 @@ class CodegenContext { s"${fullName}_$id" } + /** + * Creates an `ExprValue` representing a local java variable of required data type. + */ + def freshVariable(name: String, dt: DataType): VariableValue = + JavaCode.variable(freshName(name), dt) + + /** + * Creates an `ExprValue` representing a local java variable of required Java class. + */ + def freshVariable(name: String, javaClass: Class[_]): VariableValue = + JavaCode.variable(freshName(name), javaClass) + /** * Generates code for equal expression in Java. 
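   * For primitive inputs this boils down to a plain `c1 == c2` comparison (for example
   * `value_1 == value_2` for two int columns; the names are illustrative), array/struct types
   * recurse through `genComp`, and un-comparable types hit the IllegalArgumentException below,
   * which now reports the offending type with `catalogString` instead of `simpleString`.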
*/ @@ -594,7 +610,7 @@ class CodegenContext { case NullType => "false" case _ => throw new IllegalArgumentException( - "cannot generate equality code for un-comparable type: " + dataType.simpleString) + "cannot generate equality code for un-comparable type: " + dataType.catalogString) } /** @@ -681,7 +697,7 @@ class CodegenContext { case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException( - "cannot generate compare code for un-comparable type: " + dataType.simpleString) + "cannot generate compare code for un-comparable type: " + dataType.catalogString) } /** @@ -730,6 +746,73 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. + * + * @param arrayName name of the array to create + * @param numElements code representing the number of elements the array should contain + * @param elementType data type of the elements in the array + * @param additionalErrorMessage string to include in the error message + */ + def createUnsafeArray( + arrayName: String, + numElements: String, + elementType: DataType, + additionalErrorMessage: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + + s""" + |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | ${elementType.defaultSize}); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + + | "$additionalErrorMessage"); + |} + |byte[] $arrayBytes = new byte[(int)$arraySize]; + |UnsafeArrayData $arrayName = new UnsafeArrayData(); + |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + """.stripMargin + } + + /** + * Generates code creating a [[UnsafeArrayData]]. The generated code executes + * a provided fallback when the size of backing array would exceed the array size limit. + * @param arrayName a name of the array to create + * @param numElements a piece of code representing the number of elements the array should contain + * @param elementSize a size of an element in bytes + * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] + * and getting the backing array as a parameter + * @param fallbackCode a piece of code executed when the array size limit is exceeded + */ + def createUnsafeArrayWithFallback( + arrayName: String, + numElements: String, + elementSize: Int, + bodyCode: String => String, + fallbackCode: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + s""" + |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | $elementSize); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | $fallbackCode + |} else { + | final byte[] $arrayBytes = new byte[(int)$arraySize]; + | UnsafeArrayData $arrayName = new UnsafeArrayData(); + | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + | ${bodyCode(arrayBytes)} + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. 
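   * A rough sketch of the emitted shape when the input is nullable (variable names are
   * illustrative):
   * {{{
   *   if (!inputIsNull) {
   *     // ... code that consumes the non-null input ...
   *   }
   * }}}
   * When the input is known to be non-nullable, the body is emitted as-is with no check.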
@@ -750,6 +833,36 @@ class CodegenContext { } } + /** + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. + * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. + */ + def nullArrayElementsSaveExec( + nullElements: Boolean, + isNull: String, + arrayData: String)( + execute: String): String = { + val i = freshName("idx") + if (nullElements) { + s""" + |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) { + | $isNull |= $arrayData.isNullAt($i); + |} + |if (!$isNull) { + | $execute + |} + """.stripMargin + } else { + execute + } + } + /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -988,7 +1101,7 @@ class CodegenContext { val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(localSubExprEliminationExprs.put(_, state)) - eval.code.trim + eval.code.toString } SubExprCodes(codes, localSubExprEliminationExprs.toMap) } @@ -1016,7 +1129,7 @@ class CodegenContext { val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} + | ${eval.code} | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} @@ -1073,13 +1186,8 @@ class CodegenContext { def registerComment( text: => String, placeholderId: String = "", - force: Boolean = false): String = { - // By default, disable comments in generated code because computing the comments themselves can - // be extremely expensive in certain cases, such as deeply-nested expressions which operate over - // inputs with wide schemas. For more details on the performance issues that motivated this - // flat, see SPARK-15680. - if (force || - SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + force: Boolean = false): Block = { + if (force || SQLConf.get.codegenComments) { val name = if (placeholderId != "") { assert(!placeHolderToComments.contains(placeholderId)) placeholderId @@ -1092,9 +1200,9 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - s"/*$name*/" + code"/*$name*/" } else { - "" + EmptyBlock } } } @@ -1162,7 +1270,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { - // This is the value of HugeMethodLimit in the OpenJDK JVM settings + // This is the default value of HugeMethodLimit in the OpenJDK HotSpot JVM, + // beyond which methods will be rejected from JIT compilation final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 // The max valid length of method parameters in JVM. 
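  // (Context for the next hunk: together with the reworded DEFAULT_JVM_HUGE_METHOD_LIMIT comment
  // above, the bytecode-statistics code now logs every generated method whose bytecode size
  // exceeds that limit, since HotSpot refuses to JIT-compile such methods and they keep running
  // in the interpreter.)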
@@ -1220,7 +1329,7 @@ object CodeGenerator extends Logging { evaluator.setParentClassLoader(parentClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") - evaluator.setDefaultImports(Array( + evaluator.setDefaultImports( classOf[Platform].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, @@ -1235,7 +1344,7 @@ object CodeGenerator extends Logging { classOf[TaskContext].getName, classOf[TaskKilledException].getName, classOf[InputMetrics].getName - )) + ) evaluator.setExtendedClass(classOf[GeneratedClass]) logDebug({ @@ -1289,9 +1398,15 @@ object CodeGenerator extends Logging { try { val cf = new ClassFile(new ByteArrayInputStream(classBytes)) val stats = cf.methodInfos.asScala.flatMap { method => - method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => + method.getAttributes().filter(_.getClass eq codeAttr).map { a => val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) + + if (byteCodeSize > DEFAULT_JVM_HUGE_METHOD_LIMIT) { + logInfo("Generated method too long to be JIT compiled: " + + s"${cf.getThisClassName}.${method.getName} is $byteCodeSize bytes") + } + byteCodeSize } } @@ -1316,7 +1431,7 @@ object CodeGenerator extends Logging { * weak keys/values and thus does not respond to memory pressure. */ private val cache = CacheBuilder.newBuilder() - .maximumSize(100) + .maximumSize(SQLConf.get.codegenCacheMaxEntries) .build( new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { override def load(code: CodeAndComment): (GeneratedClass, Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a91989e129664..3f4704d287cbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ /** * A trait that can be used to provide a fallback mode for expression code generation. 
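 * As the hunk below shows, the fallback simply evaluates the interpreted expression registered in
 * `references` (`((Expression) references[idx]).eval(input)`) and casts the boxed result; the
 * only change here is that the snippet is now assembled with the `code"..."` interpolator, so
 * `ev.isNull` and `ev.value` stay tracked as `ExprValue`s inside a `Block` instead of being
 * flattened into a plain string.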
@@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { val placeHolder = ctx.registerComment(this.toString) val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; @@ -55,7 +56,7 @@ trait CodegenFallback extends Expression { ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 01c350e9dbf69..39778661d1c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) val code = - s""" + code""" |final InternalRow $tmpInput = $input; |final Object[] $values = new Object[${schema.length}]; |$allFields @@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), elementType) - val code = s""" + val code = code""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; @@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) - val code = s""" + val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 01b4d6c4529bd..998a675eecc62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -86,7 +87,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // For top level row writer, it 
always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call // `reset` to set up its fixed-size region every time. - if (inputs.map(_.isNull).forall(_ == "false")) { + if (inputs.map(_.isNull).forall(_ == FalseLiteral)) { // If all fields are not nullable, which means the null bits never changes, then we don't // need to clear it out every time. "" @@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = - s""" + code""" |$rowWriter.reset(); |$evalSubexpr |$writeExpressions @@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | } | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | return ${eval.value}; | } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff018488863..17d4a0dc4e884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -112,6 +114,192 @@ object JavaCode { def isNullExpression(code: String): SimpleExprValue = { expression(code, BooleanType) } + + /** + * Create an `Inline` for Java Class name. + */ + def javaType(javaClass: Class[_]): Inline = Inline(javaClass.getName) + + /** + * Create an `Inline` for Java Type name. + */ + def javaType(dataType: DataType): Inline = Inline(CodeGenerator.javaType(dataType)) + + /** + * Create an `Inline` for boxed Java Type name. + */ + def boxedType(dataType: DataType): Inline = Inline(CodeGenerator.boxedType(dataType)) +} + +/** + * A trait representing a block of java code. + */ +trait Block extends TreeNode[Block] with JavaCode { + import Block._ + + // Returns java code string for this code block. + override def toString: String = _marginChar match { + case Some(c) => code.stripMargin(c).trim + case _ => code.trim + } + + def length: Int = toString.length + + def isEmpty: Boolean = toString.isEmpty + + def nonEmpty: Boolean = !isEmpty + + // The leading prefix that should be stripped from each line. + // By default we strip blanks or control characters followed by '|' from the line. + var _marginChar: Option[Char] = Some('|') + + def stripMargin(c: Char): this.type = { + _marginChar = Some(c) + this + } + + def stripMargin: this.type = { + _marginChar = Some('|') + this + } + + /** + * Apply a map function to each java expression codes present in this java code, and return a new + * java code based on the mapped java expression codes. 
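   * A minimal usage sketch (assuming `oldValue` and `newValue` are `ExprValue` instances already
   * in hand):
   * {{{
   *   val rewritten = block.transformExprValues {
   *     case v if v == oldValue => newValue
   *   }
   * }}}
   * Only the tracked expression inputs are rewritten; the literal code parts are left untouched.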
+ */ + def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = { + var changed = false + + @inline def transform(e: ExprValue): ExprValue = { + val newE = f lift e + if (!newE.isDefined || newE.get.equals(e)) { + e + } else { + changed = true + newE.get + } + } + + def doTransform(arg: Any): AnyRef = arg match { + case e: ExprValue => transform(e) + case Some(value) => Some(doTransform(value)) + case seq: Traversable[_] => seq.map(doTransform) + case other: AnyRef => other + } + + val newArgs = mapProductIterator(doTransform) + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + + // Concatenates this block with other block. + def + (other: Block): Block = other match { + case EmptyBlock => this + case _ => code"$this\n$other" + } + + override def verboseString: String = toString +} + +object Block { + + val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + + /** + * A custom string interpolator which inlines a string into code block. + */ + implicit class InlineHelper(val sc: StringContext) extends AnyVal { + def inline(args: Any*): Inline = { + val inlineString = sc.raw(args: _*) + Inline(inlineString) + } + } + + implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) + + implicit class BlockHelper(val sc: StringContext) extends AnyVal { + def code(args: Any*): Block = { + sc.checkLengths(args) + if (sc.parts.length == 0) { + EmptyBlock + } else { + args.foreach { + case _: ExprValue | _: Inline | _: Block => + case _: Int | _: Long | _: Float | _: Double | _: String => + case other => throw new IllegalArgumentException( + s"Can not interpolate ${other.getClass.getName} into code block.") + } + + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) + CodeBlock(codeParts, blockInputs) + } + } + } + + // Folds eagerly the literal args into the code parts. + private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = { + val codeParts = ArrayBuffer.empty[String] + val blockInputs = ArrayBuffer.empty[JavaCode] + + val strings = parts.iterator + val inputs = args.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + + buf.append(strings.next) + while (strings.hasNext) { + val input = inputs.next + input match { + case _: ExprValue | _: CodeBlock => + codeParts += buf.toString + buf.clear + blockInputs += input.asInstanceOf[JavaCode] + case EmptyBlock => + case _ => + buf.append(input) + } + buf.append(strings.next) + } + codeParts += buf.toString + + (codeParts.toSeq, blockInputs.toSeq) + } +} + +/** + * A block of java code. Including a sequence of code parts and some inputs to this block. + * The actual java code is generated by embedding the inputs into the code parts. Here we keep + * inputs of `JavaCode` instead of simply folding them as a string of code, because we need to + * track expressions (`ExprValue`) in this code block. We need to be able to manipulate the + * expressions later without changing the behavior of this code block in some applications, e.g., + * method splitting. 
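 * For instance (illustrative), {{{ code"boolean ${ev.isNull} = false;" }}} yields a CodeBlock
 * whose `blockInputs` holds the `ev.isNull` ExprValue, so later passes such as method splitting
 * can still locate and rewrite that reference, which the old `s"..."` interpolation could not
 * support because it folded everything into one opaque string.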
+ */ +case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { + override def children: Seq[Block] = + blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]] + + override lazy val code: String = { + val strings = codeParts.iterator + val inputs = blockInputs.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + buf.append(StringContext.treatEscapes(strings.next)) + while (strings.hasNext) { + buf.append(inputs.next) + buf.append(StringContext.treatEscapes(strings.next)) + } + buf.toString + } +} + +case object EmptyBlock extends Block with Serializable { + override val code: String = "" + override def children: Seq[Block] = Seq.empty +} + +/** + * A piece of java code snippet inlines all types of input arguments into a string without + * tracking any reference of `JavaCode` instances. + */ +case class Inline(codeString: String) extends JavaCode { + override val code: String = codeString } /** @@ -123,10 +311,9 @@ trait ExprValue extends JavaCode { } object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code } - /** * A java expression fragment. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c16793bda028e..cf9796ef1948f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -16,49 +16,110 @@ */ package org.apache.spark.sql.catalyst.expressions -import java.util.Comparator +import java.util.{Comparator, TimeZone} + +import scala.collection.mutable +import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.collection.OpenHashSet + +/** + * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit + * casting. 
+ */ +trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression + with ImplicitCastInputTypes { + + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => + TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + + s"been two ${ArrayType.simpleString}s with same element type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]") + } + } +} + /** - * Given an array or map, returns its size. Returns -1 if null. + * Given an array or map, returns total number of elements in it. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the size of an array or a map. Returns -1 if null.", + usage = """ + _FUNC_(expr) - Returns the size of an array or a map. + The function returns -1 if its input is null and spark.sql.legacy.sizeOfNull is set to true. + If spark.sql.legacy.sizeOfNull is set to false, the function returns null for null input. + By default, the spark.sql.legacy.sizeOfNull parameter is set to true. + """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', 'c', 'a')); 4 + > SELECT _FUNC_(map('a', 1, 'b', 2)); + 2 + > SELECT _FUNC_(NULL); + -1 """) case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + val legacySizeOfNull = SQLConf.get.legacySizeOfNull + override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) - override def nullable: Boolean = false + override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { - -1 + if (legacySizeOfNull) -1 else null } else child.dataType match { case _: ArrayType => value.asInstanceOf[ArrayData].numElements() case _: MapType => value.asInstanceOf[MapData].numElements() + case other => throw new UnsupportedOperationException( + s"The size function doesn't support the operand type ${other.getClass.getCanonicalName}") } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - ev.copy(code = s""" + if (legacySizeOfNull) { + val childGen = child.genCode(ctx) + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = FalseLiteral) + } else { + defineCodeGen(ctx, ev, c => s"($c).numElements()") + } } } @@ -90,6 +151,167 @@ case class MapKeys(child: Expression) override def prettyName: String = "map_keys" } +@ExpressionDescription( + usage = """ + _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. 
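    (Behavioral note derived from the eval and codegen further below: when the input arrays differ
    in length, the result takes the length of the longest array and the missing entries of the
    shorter arrays are filled with NULL.)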
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4)); + [[1, 2], [2, 3], [3, 4]] + > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4)); + [[1, 2, 3], [2, 3, 4]] + """, + since = "2.4.0") +case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) + + @transient override lazy val dataType: DataType = { + val fields = children.zip(arrayElementTypes).zipWithIndex.map { + case ((expr: NamedExpression, elementType), _) => + StructField(expr.name, elementType, nullable = true) + case ((_, elementType), idx) => + StructField(idx.toString, elementType, nullable = true) + } + ArrayType(StructType(fields), containsNull = false) + } + + override def nullable: Boolean = children.exists(_.nullable) + + @transient private lazy val arrayElementTypes = + children.map(_.dataType.asInstanceOf[ArrayType].elementType) + + private def genericArrayData = classOf[GenericArrayData].getName + + def emptyInputGenCode(ev: ExprCode): ExprCode = { + ev.copy(code""" + |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]); + |boolean ${ev.isNull} = false; + """.stripMargin) + } + + def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val genericInternalRow = classOf[GenericInternalRow].getName + val arrVals = ctx.freshName("arrVals") + val biggestCardinality = ctx.freshName("biggestCardinality") + + val currentRow = ctx.freshName("currentRow") + val j = ctx.freshName("j") + val i = ctx.freshName("i") + val args = ctx.freshName("args") + + val evals = children.map(_.genCode(ctx)) + val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) => + s""" + |if ($biggestCardinality != -1) { + | ${eval.code} + | if (!${eval.isNull}) { + | $arrVals[$index] = ${eval.value}; + | $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements()); + | } else { + | $biggestCardinality = -1; + | } + |} + """.stripMargin + } + + val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs( + expressions = getValuesAndCardinalities, + funcName = "getValuesAndCardinalities", + returnType = "int", + makeSplitFunction = body => + s""" + |$body + |return $biggestCardinality; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), + extraArguments = + ("ArrayData[]", arrVals) :: + ("int", biggestCardinality) :: Nil) + + val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) => + val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i) + s""" + |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) { + | $currentRow[$idx] = $g; + |} else { + | $currentRow[$idx] = null; + |} + """.stripMargin + } + + val getValueForTypeSplitted = ctx.splitExpressions( + expressions = getValueForType, + funcName = "extractValue", + arguments = + ("int", i) :: + ("Object[]", currentRow) :: + ("ArrayData[]", arrVals) :: Nil) + + val initVariables = s""" + |ArrayData[] $arrVals = new ArrayData[${children.length}]; + |int $biggestCardinality = 0; + |${CodeGenerator.javaType(dataType)} ${ev.value} = null; + """.stripMargin + + ev.copy(code""" + |$initVariables + |$splittedGetValuesAndCardinalities + |boolean ${ev.isNull} = $biggestCardinality == -1; + |if (!${ev.isNull}) { + | Object[] $args = new Object[$biggestCardinality]; + | for (int $i = 0; $i < $biggestCardinality; $i ++) { + | Object[] $currentRow = new 
Object[${children.length}]; + | $getValueForTypeSplitted + | $args[$i] = new $genericInternalRow($currentRow); + | } + | ${ev.value} = new $genericArrayData($args); + |} + """.stripMargin) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (children.length == 0) { + emptyInputGenCode(ev) + } else { + nonEmptyInputGenCode(ctx, ev) + } + } + + override def eval(input: InternalRow): Any = { + val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData]) + if (inputArrays.contains(null)) { + null + } else { + val biggestCardinality = if (inputArrays.isEmpty) { + 0 + } else { + inputArrays.map(_.numElements()).max + } + + val result = new Array[InternalRow](biggestCardinality) + val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex + + for (i <- 0 until biggestCardinality) { + val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) => + if (i < arr.numElements() && !arr.isNullAt(i)) { + arr.get(i, arrayElementTypes(index)) + } else { + null + } + } + + result(i) = InternalRow.apply(currentLayer: _*) + } + new GenericArrayData(result) + } + } + + override def prettyName: String = "arrays_zip" +} + /** * Returns an unordered array containing the values of the map. */ @@ -119,249 +341,1130 @@ case class MapValues(child: Expression) } /** - * Sorts the input array in ascending / descending order according to the natural ordering of - * the array elements and returns it. + * Returns an unordered array of all entries in the given map. */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.", + usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", examples = """ Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true); - ["a","b","c","d"] - """) -// scalastyle:on line.size.limit -case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + > SELECT _FUNC_(map(1, 'a', 2, 'b')); + [(1,"a"),(2,"b")] + """, + since = "2.4.0") +case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { - def this(e: Expression) = this(e, Literal(true)) + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + @transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - "Sort order in second argument requires a boolean literal.") - } - case ArrayType(dt, _) => - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type ${dt.simpleString}") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + override def dataType: DataType = { + ArrayType( + StructType( + StructField("key", childDataType.keyType, false) :: + StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: + Nil), + false) } - @transient - private lazy val lt: 
Comparator[Any] = { - val ordering = base.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] - } - - new Comparator[Any]() { - override def compare(o1: Any, o2: Any): Int = { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - -1 - } else if (o2 == null) { - 1 - } else { - ordering.compare(o1, o2) - } - } + override protected def nullSafeEval(input: Any): Any = { + val childMap = input.asInstanceOf[MapData] + val keys = childMap.keyArray() + val values = childMap.valueArray() + val length = childMap.numElements() + val resultData = new Array[AnyRef](length) + var i = 0; + while (i < length) { + val key = keys.get(i, childDataType.keyType) + val value = values.get(i, childDataType.valueType) + val row = new GenericInternalRow(Array[Any](key, value)) + resultData.update(i, row) + i += 1 } + new GenericArrayData(resultData) } - @transient - private lazy val gt: Comparator[Any] = { - val ordering = base.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] - } - - new Comparator[Any]() { - override def compare(o1: Any, o2: Any): Int = { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - 1 - } else if (o2 == null) { - -1 - } else { - -ordering.compare(o1, o2) - } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + } else { + genCodeForAnyElements(ctx, keys, values, ev.value, numElements) } - } - } - - override def nullSafeEval(array: Any, ascending: Any): Any = { - val elementType = base.dataType.asInstanceOf[ArrayType].elementType - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementType != NullType) { - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) - } - new GenericArrayData(data.asInstanceOf[Array[Any]]) + s""" + |final int $numElements = $c.numElements(); + |final ArrayData $keys = $c.keyArray(); + |final ArrayData $values = $c.valueArray(); + |$code + """.stripMargin + }) } - override def prettyName: String = "sort_array" -} - -/** - * Returns a reversed string or an array with reverse order of elements. - */ -@ExpressionDescription( - usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", - examples = """ - Examples: - > SELECT _FUNC_('Spark SQL'); - LQS krapS - > SELECT _FUNC_(array(2, 1, 4, 3)); - [3, 4, 1, 2] - """, - since = "1.5.0", - note = "Reverse logic for arrays is available since 2.4.0." -) -case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - // Input types are utilized by type coercion in ImplicitTypeCasts. 
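  // (Reading aid for the codegen that follows: for primitive element types the generated code
  // copies the incoming array data and swaps elements pairwise in place, which is why it only
  // iterates length / 2 times; for non-primitive types it allocates a fresh GenericArrayData and
  // fills position k with element length - k - 1.)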
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) - - override def dataType: DataType = child.dataType - - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") - override def nullSafeEval(input: Any): Any = input match { - case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) - case s: UTF8String => s.reverse() - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => dataType match { - case _: StringType => stringCodeGen(ev, c) - case _: ArrayType => arrayCodeGen(ctx, ev, c) - }) + private def getValue(varName: String) = { + CodeGenerator.getValue(varName, childDataType.valueType, "z") } - private def stringCodeGen(ev: ExprCode, childName: String): String = { - s"${ev.value} = ($childName).reverse();" - } + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val unsafeRow = ctx.freshName("unsafeRow") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" - private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { - val length = ctx.freshName("length") - val javaElementType = CodeGenerator.javaType(elementType) - val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val structSizeAsLong = structSize + "L" + val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType) - val initialization = if (isPrimitiveType) { - s"$childName.copy()" + val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" + val valueAssignmentChecked = if (childDataType.valueContainsNull) { + s""" + |if ($values.isNullAt(z)) { + | $unsafeRow.setNullAt(1); + |} else { + | $valueAssignment + |} + """.stripMargin } else { - s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + valueAssignment } - val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + val assignmentLoop = (byteArray: String) => + s""" + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSizeAsLong; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); + | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); + | $valueAssignmentChecked + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + + ctx.createUnsafeArrayWithFallback( + unsafeArrayData, + numElements, + structSize + wordSize, + assignmentLoop, + genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val data = ctx.freshName("internalRowArray") - val 
swapAssigments = if (isPrimitiveType) { - val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) - val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) - s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); - |boolean isNullAtL = ${ev.value}.isNullAt(l); - |if(!isNullAtK) { - | $javaElementType el = ${getCall("k")}; - | if(!isNullAtL) { - | ${ev.value}.$setFunc(k, ${getCall("l")}); - | } else { - | ${ev.value}.setNullAt(k); - | } - | ${ev.value}.$setFunc(l, el); - |} else if (!isNullAtL) { - | ${ev.value}.$setFunc(k, ${getCall("l")}); - | ${ev.value}.setNullAt(l); - |}""".stripMargin + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { + s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" } else { - s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + getValue(values) } s""" - |final int $length = $childName.numElements(); - |${ev.value} = $initialization; - |for(int k = 0; k < $numberOfIterations; k++) { - | int l = $length - k - 1; - | $swapAssigments + |final Object[] $data = new Object[$numElements]; + |for (int z = 0; z < $numElements; z++) { + | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); |} + |$arrayData = new $genericArrayClass($data); """.stripMargin } - override def prettyName: String = "reverse" + override def prettyName: String = "map_entries" } /** - * Checks if the array (left) has the element (right) + * Returns the union of all the given maps. */ @ExpressionDescription( - usage = "_FUNC_(array, value) - Returns true if the array contains the value.", + usage = "_FUNC_(map, ...) - Returns the union of all the given maps", examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 2); - true - """) -case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def dataType: DataType = BooleanType - - override def inputTypes: Seq[AbstractDataType] = right.dataType match { - case NullType => Seq.empty - case _ => left.dataType match { - case n @ ArrayType(element, _) => Seq(n, element) - case _ => Seq.empty - } - } + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); + [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] + """, since = "2.4.0") +case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def checkInputDataTypes(): TypeCheckResult = { - if (right.dataType == NullType) { - TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") - } else if (!left.dataType.isInstanceOf[ArrayType] - || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + var funcName = s"function $prettyName" + if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( - "Arguments must be an array followed by a value of same type as the array members") + s"input to $funcName should all be of type map, but it's " + + children.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) } } - override def nullable: Boolean = { - left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull - } - - override def nullSafeEval(arr: Any, value: Any): Any = { - var hasNull = false - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == null) { - 
hasNull = true - } else if (v == value) { - return true - } - ) - if (hasNull) { - null + @transient override lazy val dataType: MapType = { + if (children.isEmpty) { + MapType(StringType, StringType) } else { - false + super.dataType.asInstanceOf[MapType] } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (arr, value) => { - val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(arr, right.dataType, i) - s""" + override def nullable: Boolean = children.exists(_.nullable) + + override def eval(input: InternalRow): Any = { + val maps = children.map(_.eval(input)) + if (maps.contains(null)) { + return null + } + val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) + val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) + + val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + + s"elements due to exceeding the map size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + val finalKeyArray = new Array[AnyRef](numElements.toInt) + val finalValueArray = new Array[AnyRef](numElements.toInt) + var position = 0 + for (i <- keyArrayDatas.indices) { + val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) + val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) + Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) + Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) + position += keyArray.length + } + + new ArrayBasedMapData(new GenericArrayData(finalKeyArray), + new GenericArrayData(finalValueArray)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val mapCodes = children.map(_.genCode(ctx)) + val keyType = dataType.keyType + val valueType = dataType.valueType + val argsName = ctx.freshName("args") + val hasNullName = ctx.freshName("hasNull") + val mapDataClass = classOf[MapData].getName + val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName + val arrayDataClass = classOf[ArrayData].getName + + val init = + s""" + |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; + |boolean ${ev.isNull}, $hasNullName = false; + |$mapDataClass ${ev.value} = null; + """.stripMargin + + val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map { + case ((m, true), i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | if (!${m.isNull}) { + | $argsName[$i] = ${m.value}; + | } else { + | $hasNullName = true; + | } + |} + """.stripMargin + case ((m, false), i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + |} + """.stripMargin + } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = assignments, + funcName = "getMapConcatInputs", + extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNullName; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n") + ) + + val idxName = ctx.freshName("idx") + val numElementsName = ctx.freshName("numElems") + val finKeysName = ctx.freshName("finalKeys") + val finValsName = ctx.freshName("finalValues") + + val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) { + genCodeForPrimitiveArrays(ctx, keyType, false) + } else { + genCodeForNonPrimitiveArrays(ctx, 
keyType) + } + + val valueConcat = + if (valueType.sameType(keyType) && + !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) { + keyConcat + } else if (CodeGenerator.isPrimitiveType(valueType)) { + genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) + } else { + genCodeForNonPrimitiveArrays(ctx, valueType) + } + + val keyArgsName = ctx.freshName("keyArgs") + val valArgsName = ctx.freshName("valArgs") + + val mapMerge = + s""" + |${ev.isNull} = $hasNullName; + |if (!${ev.isNull}) { + | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; + | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; + | long $numElementsName = 0; + | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); + | } + | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | } + | $arrayDataClass $finKeysName = $keyConcat($keyArgsName, + | (int) $numElementsName); + | $arrayDataClass $finValsName = $valueConcat($valArgsName, + | (int) $numElementsName); + | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); + |} + """.stripMargin + + ev.copy( + code = code""" + |$init + |$codes + |$mapMerge + """.stripMargin) + } + + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + val setterCode1 = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} + |);""".stripMargin + + val setterCode = if (checkForNull) { + s""" + |if ($argsName[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode1 + |}""".stripMargin + } else { + setterCode1 + } + + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $setterCode + | $counter++; + | } + | } + | return $arrayData; + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { + | Object[] $arrayData = new Object[$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $arrayData[$counter] = 
${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) + } + + override def prettyName: String = "map_concat" +} + +/** + * Returns a map created from the given array of entries. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", + examples = """ + Examples: + > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); + {1:"a",2:"b"} + """, + since = "2.4.0") +case class MapFromEntries(child: Expression) extends UnaryExpression { + + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + case _ => None + } + + @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3 + + override def nullable: Boolean = child.nullable || nullEntries + + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 + + override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { + case Some(_) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + + s"${child.dataType.catalogString} type. $prettyName accepts only arrays of pair structs.") + } + + override protected def nullSafeEval(input: Any): Any = { + val arrayData = input.asInstanceOf[ArrayData] + val numEntries = arrayData.numElements() + var i = 0 + if(nullEntries) { + while (i < numEntries) { + if (arrayData.isNullAt(i)) return null + i += 1 + } + } + val keyArray = new Array[AnyRef](numEntries) + val valueArray = new Array[AnyRef](numEntries) + i = 0 + while (i < numEntries) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) + i += 1 + } + ArrayBasedMapData(keyArray, valueArray) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numEntries = ctx.freshName("numEntries") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) + } else { + genCodeForAnyElements(ctx, c, ev.value, numEntries) + } + ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { + s""" + |final int $numEntries = $c.numElements(); + |$code + """.stripMargin + } + }) + } + + private def genCodeForAssignmentLoop( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String, + keyAssignment: (String, String) => String, + valueAssignment: (String, String) => String): String = { + val entry = ctx.freshName("entry") + val i = ctx.freshName("idx") + + val nullKeyCheck = if (dataTypeDetails.get._2) { + s""" + |if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + |} + """.stripMargin + } else { + "" + } + + s""" + |for (int $i = 0; $i < $numEntries; $i++) { + | InternalRow $entry = 
$childVariable.getStruct($i, 2); + | $nullKeyCheck + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} + | ${valueAssignment(entry, i)} + |} + """.stripMargin + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val keySectionSize = ctx.freshName("keySectionSize") + val valueSectionSize = ctx.freshName("valueSectionSize") + val data = ctx.freshName("byteArray") + val unsafeMapData = ctx.freshName("unsafeMapData") + val keyArrayData = ctx.freshName("keyArrayData") + val valueArrayData = ctx.freshName("valueArrayData") + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val keySize = dataType.keyType.defaultSize + val valueSize = dataType.valueType.defaultSize + val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" + val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) + + val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" + if (dataType.valueContainsNull) { + s""" + |if ($entry.isNullAt(1)) { + | $valueArrayData.setNullAt($idx); + |} else { + | $valueNullUnsafeAssignment + |} + """.stripMargin + } else { + valueNullUnsafeAssignment + } + } + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment + ) + + s""" + |final long $keySectionSize = $kByteSize; + |final long $valueSectionSize = $vByteSize; + |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; + |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeMapData $unsafeMapData = new UnsafeMapData(); + | Platform.putLong($data, $baseOffset, $keySectionSize); + | Platform.putLong($data, ${baseOffset + 8}, $numEntries); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); + | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); + | ArrayData $keyArrayData = $unsafeMapData.keyArray(); + | ArrayData $valueArrayData = $unsafeMapData.valueArray(); + | $assignmentLoop + | $mapData = $unsafeMapData; + |} + """.stripMargin + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val mapDataClass = classOf[ArrayBasedMapData].getName() + + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + if (dataType.valueContainsNull && isValuePrimitive) { + s"$values[$idx] = $entry.isNullAt(1) ? 
null : (Object)$value;" + } else { + s"$values[$idx] = $value;" + } + } + val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment) + + s""" + |final Object[] $keys = new Object[$numEntries]; + |final Object[] $values = new Object[$numEntries]; + |$assignmentLoop + |$mapData = $mapDataClass.apply($keys, $values); + """.stripMargin + } + + override def prettyName: String = "map_from_entries" +} + + +/** + * Common base class for [[SortArray]] and [[ArraySort]]. + */ +trait ArraySortLike extends ExpectsInputTypes { + protected def arrayExpression: Expression + + protected def nullOrder: NullOrder + + @transient private lazy val lt: Comparator[Any] = { + val ordering = arrayExpression.dataType match { + case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + } + + new Comparator[Any]() { + override def compare(o1: Any, o2: Any): Int = { + if (o1 == null && o2 == null) { + 0 + } else if (o1 == null) { + nullOrder + } else if (o2 == null) { + -nullOrder + } else { + ordering.compare(o1, o2) + } + } + } + } + + @transient private lazy val gt: Comparator[Any] = { + val ordering = arrayExpression.dataType match { + case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + } + + new Comparator[Any]() { + override def compare(o1: Any, o2: Any): Int = { + if (o1 == null && o2 == null) { + 0 + } else if (o1 == null) { + -nullOrder + } else if (o2 == null) { + nullOrder + } else { + ordering.compare(o2, o1) + } + } + } + } + + @transient lazy val elementType: DataType = + arrayExpression.dataType.asInstanceOf[ArrayType].elementType + + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + + def sortEval(array: Any, ascending: Boolean): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementType != NullType) { + java.util.Arrays.sort(data, if (ascending) lt else gt) + } + new GenericArrayData(data.asInstanceOf[Array[Any]]) + } + + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { + val arrayData = classOf[ArrayData].getName + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val array = ctx.freshName("array") + val c = ctx.freshName("c") + if (elementType == NullType) { + s"${ev.value} = $base.copy();" + } else { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + val sortOrder = ctx.freshName("sortOrder") + val o1 = ctx.freshName("o1") + val o2 = ctx.freshName("o2") + val jt = CodeGenerator.javaType(elementType) + val comp = if (CodeGenerator.isPrimitiveType(elementType)) { + val bt = CodeGenerator.boxedType(elementType) + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$jt $v1 = (($bt) $o1).${jt}Value(); + |$jt $v2 = (($bt) $o2).${jt}Value(); + |int $c = ${ctx.genComp(elementType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" + } + val 
nonNullPrimitiveAscendingSort = + if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val javaType = CodeGenerator.javaType(elementType) + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($order) { + | $javaType[] $array = $base.to${primitiveTypeName}Array(); + | java.util.Arrays.sort($array); + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); + |} else + """.stripMargin + } else { + "" + } + s""" + |$nonNullPrimitiveAscendingSort + |{ + | Object[] $array = $base.toObjectArray($elementTypeTerm); + | final int $sortOrder = $order ? 1 : -1; + | java.util.Arrays.sort($array, new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | if ($o1 == null && $o2 == null) { + | return 0; + | } else if ($o1 == null) { + | return $sortOrder * $nullOrder; + | } else if ($o2 == null) { + | return -$sortOrder * $nullOrder; + | } + | $comp + | return $sortOrder * $c; + | } + | }); + | ${ev.value} = new $genericArrayData($array); + |} + """.stripMargin + } + } + +} + +object ArraySortLike { + type NullOrder = Int + // Least: place null element at the first of the array for ascending order + // Greatest: place null element at the end of the array for ascending order + object NullOrder { + val Least: NullOrder = -1 + val Greatest: NullOrder = 1 + } +} + +/** + * Sorts the input array in ascending / descending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. Null elements will be placed + at the beginning of the returned array in ascending order or at the end of the returned + array in descending order. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); + [null,"a","b","c","d"] + """) +// scalastyle:on line.size.limit +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ArraySortLike { + + def this(e: Expression) = this(e, Literal(true)) + + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def arrayExpression: Expression = base + override def nullOrder: NullOrder = NullOrder.Least + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } + case ArrayType(dt, _) => + val dtSimple = dt.catalogString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any, ascending: Any): Any = { + sortEval(array, ascending.asInstanceOf[Boolean]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + } + + override def prettyName: String = "sort_array" +} + + +/** + * Sorts the input array in ascending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + ["a","b","c","d",null] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { + + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def arrayExpression: Expression = child + override def nullOrder: NullOrder = NullOrder.Greatest + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + TypeCheckResult.TypeCheckSuccess + case ArrayType(dt, _) => + val dtSimple = dt.catalogString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any): Any = { + sortEval(array, true) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + } + + override def prettyName: String = "array_sort" +} + +/** + * Returns a random permutation of the given array. 
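+ * The permutation is produced by a [[RandomIndicesGenerator]] seeded per partition with
+ * `randomSeed + partitionIndex`, which is why the expression is marked non-deterministic.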
+ */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a random permutation of the given array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, 3, 5)); + [3, 1, 5, 20] + > SELECT _FUNC_(array(1, 20, null, 3)); + [20, null, 3, 1] + """, + note = "The function is non-deterministic.", + since = "2.4.0") +case class Shuffle(child: Expression, randomSeed: Option[Long] = None) + extends UnaryExpression with ExpectsInputTypes with Stateful with ExpressionWithRandomSeed { + + def this(child: Expression) = this(child, None) + + override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) + + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private[this] var random: RandomIndicesGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = { + random = RandomIndicesGenerator(randomSeed.get + partitionIndex) + } + + override protected def evalInternal(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + val source = value.asInstanceOf[ArrayData] + val numElements = source.numElements() + val indices = random.getNextIndices(numElements) + new GenericArrayData(indices.map(source.get(_, elementType))) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => shuffleArrayCodeGen(ctx, ev, c)) + } + + private def shuffleArrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val randomClass = classOf[RandomIndicesGenerator].getName + + val rand = ctx.addMutableState(randomClass, "rand", forceInline = true) + ctx.addPartitionInitializationStatement( + s"$rand = new $randomClass(${randomSeed.get}L + partitionIndex);") + + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + + val initialization = if (isPrimitiveType) { + ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") + } else { + val arrayDataClass = classOf[GenericArrayData].getName() + s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" + } + + val indices = ctx.freshName("indices") + val i = ctx.freshName("i") + + val getValue = CodeGenerator.getValue(childName, elementType, s"$indices[$i]") + + val setFunc = if (isPrimitiveType) { + s"set${CodeGenerator.primitiveTypeName(elementType)}" + } else { + "update" + } + + val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($childName.isNullAt($indices[$i])) { + | $arrayData.setNullAt($i); + |} else { + | $arrayData.$setFunc($i, $getValue); + |} + """.stripMargin + } else { + s"$arrayData.$setFunc($i, $getValue);" + } + + s""" + |int $numElements = $childName.numElements(); + |int[] $indices = $rand.getNextIndices($numElements); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | $assignment + |} + |${ev.value} = $arrayData; + """.stripMargin + } + + override def freshCopy(): Shuffle = Shuffle(child, randomSeed) +} + +/** + * Returns a reversed string or an array with reverse order of elements. 
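+ * For string input the work is delegated to UTF8String.reverse; for array input the elements
+ * are copied in reverse order, with null entries kept at their mirrored positions.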
+ */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + LQS krapS + > SELECT _FUNC_(array(2, 1, 4, 3)); + [3, 4, 1, 2] + """, + since = "1.5.0", + note = "Reverse logic for arrays is available since 2.4.0." +) +case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Input types are utilized by type coercion in ImplicitTypeCasts. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + + override def dataType: DataType = child.dataType + + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(input: Any): Any = input match { + case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) + case s: UTF8String => s.reverse() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => dataType match { + case _: StringType => stringCodeGen(ev, c) + case _: ArrayType => arrayCodeGen(ctx, ev, c) + }) + } + + private def stringCodeGen(ev: ExprCode, childName: String): String = { + s"${ev.value} = ($childName).reverse();" + } + + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + + val initialization = if (isPrimitiveType) { + ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") + } else { + val arrayDataClass = classOf[GenericArrayData].getName + s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" + } + + val i = ctx.freshName("i") + val j = ctx.freshName("j") + + val getValue = CodeGenerator.getValue(childName, elementType, i) + + val setFunc = if (isPrimitiveType) { + s"set${CodeGenerator.primitiveTypeName(elementType)}" + } else { + "update" + } + + val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($childName.isNullAt($i)) { + | $arrayData.setNullAt($j); + |} else { + | $arrayData.$setFunc($j, $getValue); + |} + """.stripMargin + } else { + s"$arrayData.$setFunc($j, $getValue);" + } + + s""" + |final int $numElements = $childName.numElements(); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | int $j = $numElements - $i - 1; + | $assignment + |} + |${ev.value} = $arrayData; + """.stripMargin + } + + override def prettyName: String = "reverse" +} + +/** + * Checks if the array (left) has the element (right) + */ +@ExpressionDescription( + usage = "_FUNC_(array, value) - Returns true if the array contains the value.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + true + """) +case class ArrayContains(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BooleanType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq.empty + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (right.dataType == NullType) { + 
TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments")
+    } else if (!left.dataType.isInstanceOf[ArrayType]
+      || !left.dataType.asInstanceOf[ArrayType].elementType.sameType(right.dataType)) {
+      TypeCheckResult.TypeCheckFailure(
+        "Arguments must be an array followed by a value of same type as the array members")
+    } else {
+      TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+    }
+  }
+
+  override def nullable: Boolean = {
+    left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull
+  }
+
+  override def nullSafeEval(arr: Any, value: Any): Any = {
+    var hasNull = false
+    arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
+      if (v == null) {
+        hasNull = true
+      } else if (ordering.equiv(v, value)) {
+        return true
+      }
+    )
+    if (hasNull) {
+      null
+    } else {
+      false
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (arr, value) => {
+      val i = ctx.freshName("i")
+      val getValue = CodeGenerator.getValue(arr, right.dataType, i)
+      s"""
      for (int $i = 0; $i < $arr.numElements(); $i ++) {
        if ($arr.isNullAt($i)) {
          ${ev.isNull} = true;
@@ -371,515 +1474,2848 @@ case class ArrayContains(left: Expression, right: Expression)
          break;
        }
      }
-      """
+      """
+    })
+  }
+
+  override def prettyName: String = "array_contains"
+}
+
+/**
+ * Checks if the two arrays contain at least one common element.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least one non-null element that is also present in a2. If the arrays have no common element, they are both non-empty, and either of them contains a null element, null is returned; false otherwise.",
+  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
       true
  """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class ArraysOverlap(left: Expression, right: Expression)
+  extends BinaryArrayExpressionWithImplicitCast {
+
+  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
+    case TypeCheckResult.TypeCheckSuccess =>
+      TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
+    case failure => failure
+  }
+
+  @transient private lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(elementType)
+
+  @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
+    fastEval _
+  } else {
+    bruteForceEval _
+  }
+
+  override def dataType: DataType = BooleanType
+
+  override def nullable: Boolean = {
+    left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
+      right.dataType.asInstanceOf[ArrayType].containsNull
+  }
+
+  override def nullSafeEval(a1: Any, a2: Any): Any = {
+    doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData])
+  }
+
+  /**
+   * A fast implementation which puts all the elements from the smaller array in a set
+   * and then performs a lookup on it for each element of the bigger one.
+   * This eval mode works only for data types that properly implement the equals method.
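+   * Null entries are only tracked while scanning: if no common element is found and either
+   * input contained a null, the result is null instead of false.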
+ */ + private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2) + } else { + (arr2, arr1) + } + if (smaller.numElements() > 0) { + val smallestSet = new mutable.HashSet[Any] + smaller.foreach(elementType, (_, v) => + if (v == null) { + hasNull = true + } else { + smallestSet += v + }) + bigger.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else if (smallestSet.contains(v1)) { + return true + } + ) + } + if (hasNull) { + null + } else { + false + } + } + + /** + * A slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + if (arr1.numElements() > 0 && arr2.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v2 == null) { + hasNull = true + } else if (ordering.equiv(v1, v2)) { + return true + } + ) + }) + } + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (a1, a2) => { + val smaller = ctx.freshName("smallerArray") + val bigger = ctx.freshName("biggerArray") + val comparisonCode = if (TypeUtils.typeWithProperEquals(elementType)) { + fastCodegen(ctx, ev, smaller, bigger) + } else { + bruteForceCodegen(ctx, ev, smaller, bigger) + } + s""" + |ArrayData $smaller; + |ArrayData $bigger; + |if ($a1.numElements() > $a2.numElements()) { + | $bigger = $a1; + | $smaller = $a2; + |} else { + | $smaller = $a1; + | $bigger = $a2; + |} + |if ($smaller.numElements() > 0) { + | $comparisonCode + |} + """.stripMargin + }) + } + + /** + * Code generation for a fast implementation which puts all the elements from the smaller array + * in a set and then performs a lookup on it for each element of the bigger one. + * It works only for data types which implements properly the equals method. + */ + private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val javaElementClass = CodeGenerator.boxedType(elementType) + val javaSet = classOf[java.util.HashSet[_]].getName + val set = ctx.freshName("set") + val addToSetFromSmallerCode = nullSafeElementCodegen( + smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen( + bigger, + i, + s""" + |if ($set.contains($getFromBigger)) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); + |for (int $i = 0; $i < $smaller.numElements(); $i ++) { + | $addToSetFromSmallerCode + |} + |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + | $elementIsInSetCode + |} + """.stripMargin + } + + /** + * Code generation for a slower evaluation which performs a nested loop and supports all the data types. 
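+   * Unlike the hash-set based variant, the generated code compares elements through
+   * ctx.genEqual, so it does not depend on the elements' equals/hashCode implementations.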
+ */ + private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val compareValues = nullSafeElementCodegen( + smaller, + j, + s""" + |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + |} + """.stripMargin, + s"${ev.isNull} = true;") + val isInSmaller = nullSafeElementCodegen( + bigger, + i, + s""" + |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) { + | $compareValues + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) { + | $isInSmaller + |} + """.stripMargin + } + + def nullSafeElementCodegen( + arrayVar: String, + index: String, + code: String, + isNullCode: String): String = { + if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { + s""" + |if ($arrayVar.isNullAt($index)) { + | $isNullCode + |} else { + | $code + |} + """.stripMargin + } else { + code + } + } + + override def prettyName: String = "arrays_overlap" +} + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval + + @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { + val startInt = startVal.asInstanceOf[Int] + val lengthInt = lengthVal.asInstanceOf[Int] + val arr = xVal.asInstanceOf[ArrayData] + val startIndex = if (startInt == 0) { + throw new RuntimeException( + s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") + } else if (startInt < 0) { + startInt + arr.numElements() + } else { + startInt - 1 + } + if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + + "length must be greater than or equal to 0.") + } + // startIndex can be negative if start is negative and its absolute value is greater than the + // number of elements in the array + if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) + } + val data = arr.toSeq[AnyRef](elementType) + new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + 
|${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + | + "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + | + "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + | $values[$i] = $getValue; + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $resLength = 0; + |} + |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} + |for (int $i = 0; $i < $resLength; $i ++) { + | if ($inputArray.isNullAt($i + $startIdx)) { + | $values.setNullAt($i); + | } else { + | $values.set$primitiveValueTypeName($i, $getValue); + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } +} + +/** + * Creates a String containing all the elements of the input array separated by the delimiter. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array + using the delimiter and an optional string to replace nulls. 
If no value is set for + nullReplacement, any null value is filtered.""", + examples = """ + Examples: + > SELECT _FUNC_(array('hello', 'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); + hello , world + """, since = "2.4.0") +case class ArrayJoin( + array: Expression, + delimiter: Expression, + nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { + + def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) + + def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = + this(array, delimiter, Some(nullReplacement)) + + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { + Seq(ArrayType(StringType), StringType, StringType) + } else { + Seq(ArrayType(StringType), StringType) + } + + override def children: Seq[Expression] = if (nullReplacement.isDefined) { + Seq(array, delimiter, nullReplacement.get) + } else { + Seq(array, delimiter) + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val arrayEval = array.eval(input) + if (arrayEval == null) return null + val delimiterEval = delimiter.eval(input) + if (delimiterEval == null) return null + val nullReplacementEval = nullReplacement.map(_.eval(input)) + if (nullReplacementEval.contains(null)) return null + + val buffer = new UTF8StringBuilder() + var firstItem = true + val nullHandling = nullReplacementEval match { + case Some(rep) => (prependDelimiter: Boolean) => { + if (!prependDelimiter) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(rep.asInstanceOf[UTF8String]) + true + } + case None => (_: Boolean) => false + } + arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { + if (item == null) { + if (nullHandling(firstItem)) { + firstItem = false + } + } else { + if (!firstItem) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(item.asInstanceOf[UTF8String]) + firstItem = false + } + }) + buffer.build() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val code = nullReplacement match { + case Some(replacement) => + val replacementGen = replacement.genCode(ctx) + val nullHandling = (buffer: String, delimiter: String, firstItem: String) => { + s""" + |if (!$firstItem) { + | $buffer.append($delimiter); + |} + |$buffer.append(${replacementGen.value}); + |$firstItem = false; + """.stripMargin + } + val execCode = if (replacement.nullable) { + ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + } else { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + s""" + |${replacementGen.code} + |$execCode + """.stripMargin + case None => genCodeForArrayAndDelimiter(ctx, ev, + (_: String, _: String, _: String) => "// nulls are ignored") + } + if (nullable) { + ev.copy( + code""" + |boolean ${ev.isNull} = true; + |UTF8String ${ev.value} = null; + |$code + """.stripMargin) + } else { + ev.copy( + code""" + |UTF8String ${ev.value} = null; + |$code + """.stripMargin, FalseLiteral) + } + } + + private def genCodeForArrayAndDelimiter( + ctx: CodegenContext, + ev: ExprCode, + nullEval: (String, String, String) => String): String = { + val arrayGen = array.genCode(ctx) + val delimiterGen = delimiter.genCode(ctx) + val buffer = 
ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val i = ctx.freshName("i") + val firstItem = ctx.freshName("firstItem") + val resultCode = + s""" + |$bufferClass $buffer = new $bufferClass(); + |boolean $firstItem = true; + |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) { + | if (${arrayGen.value}.isNullAt($i)) { + | ${nullEval(buffer, delimiterGen.value, firstItem)} + | } else { + | if (!$firstItem) { + | $buffer.append(${delimiterGen.value}); + | } + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); + | $firstItem = false; + | } + |} + |${ev.value} = $buffer.build();""".stripMargin + + if (array.nullable || delimiter.nullable) { + arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) { + delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode""".stripMargin + } + } + } else { + s""" + |${arrayGen.code} + |${delimiterGen.code} + |$resultCode""".stripMargin + } + } + + override def dataType: DataType = StringType + + override def prettyName: String = "array_join" +} + +/** + * Returns the minimum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 1 + """, since = "2.4.0") +case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode(EmptyBlock, + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + code""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfSmaller(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var min: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (min == null || ordering.lt(item, min))) { + min = item + } + ) + min + } + + @transient override lazy val dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_min" +} + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. 
NULL elements are skipped.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 20, null, 3));
+       20
+  """, since = "2.4.0")
+case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+  @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val typeCheckResult = super.checkInputDataTypes()
+    if (typeCheckResult.isSuccess) {
+      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
+    } else {
+      typeCheckResult
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val javaType = CodeGenerator.javaType(dataType)
+    val i = ctx.freshName("i")
+    val item = ExprCode(EmptyBlock,
+      isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
+      value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
+    ev.copy(code =
+      code"""
+      |${childGen.code}
+      |boolean ${ev.isNull} = true;
+      |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+      |if (!${childGen.isNull}) {
+      |  for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
+      |    ${ctx.reassignIfGreater(dataType, ev, item)}
+      |  }
+      |}
+      """.stripMargin)
+  }
+
+  override protected def nullSafeEval(input: Any): Any = {
+    var max: Any = null
+    input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
+      if (item != null && (max == null || ordering.gt(item, max))) {
+        max = item
+      }
+    )
+    max
+  }
+
+  @transient override lazy val dataType: DataType = child.dataType match {
+    case ArrayType(dt, _) => dt
+    case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
+  }
+
+  override def prettyName: String = "array_max"
+}
+
+
+/**
+ * Returns the position of the first occurrence of element in the given array as a long.
+ * Returns 0 if the given value could not be found in the array. Returns null if either of
+ * the arguments is null.
+ *
+ * NOTE: this is not a zero-based, but a 1-based index. The first element in the array has
+ * index 1.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, element) - Returns the (1-based) index of the first occurrence of element in the array, as a long.
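+      Returns 0 if the value could not be found in the array, and NULL if either argument is NULL.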
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(3, 2, 1), 1); + 3 + """, + since = "2.4.0") +case class ArrayPosition(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def dataType: DataType = LongType + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v != null && ordering.equiv(v, value)) { + return (i + 1).toLong + } + ) + 0L + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val pos = ctx.freshName("arrayPosition") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int $pos = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $pos = $i + 1; + | break; + | } + |} + |${ev.value} = (long) $pos; + """.stripMargin + }) + } +} + +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. 
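+      The function fails if index is 0, since SQL array indices start at 1.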
+ + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType + + @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull + + @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType) + + @transient override lazy val dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => mapKeyType + case _ => AnyDataType // no match for a wrong 'left' expression type + } + ) + } + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => + TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName") + case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (arrayContainsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, mapKeyType, ordering) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (arrayContainsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. 
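+ * All inputs must resolve to the same type. For array inputs the result is null if any
+ * argument is null, and the total number of elements is checked against
+ * ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.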
+ */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { + + private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been ${StringType.simpleString}," + + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + + childTypes.map(_.catalogString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + } + } + + @transient override lazy val dataType: DataType = { + if (children.isEmpty) { + StringType + } else { + super.dataType + } + } + + private def javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { + case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + " elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val args = ctx.freshName("args") + val hasNull = ctx.freshName("hasNull") + + val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map { + case ((eval, true), index) => + s""" + |if (!$hasNull) { + | ${eval.code} + | if (!${eval.isNull}) { + | $args[$index] = ${eval.value}; + | } else { + | $hasNull = true; + | } + |} + """.stripMargin + case ((eval, false), index) => + s""" + |if (!$hasNull) { + | ${eval.code} + | $args[$index] = ${eval.value}; + |} + """.stripMargin + } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNull; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n") + ) + + val (concat, initCode) = dataType match 
{ + case BinaryType => + (s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, containsNull) => + val concat = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType, containsNull) + } else { + genCodeForNonPrimitiveArrays(ctx, elementType) + } + (concat, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + + ev.copy(code = + code""" + |boolean $hasNull = false; + |$initCode + |$codes + |$javaType ${ev.value} = null; + |if (!$hasNull) { + | ${ev.value} = $concat($args); + |} + |boolean ${ev.isNull} = ${ev.value} == null; + """.stripMargin) + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + val numElements = ctx.freshName("numElements") + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + |} + """.stripMargin + + (code, numElements) + } + + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + val setterCode = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + |); + """.stripMargin + + val nullSafeSetterCode = if (checkForNull) { + s""" + |if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode + |} + """.stripMargin + } else { + setterCode + } + + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] args) { + | $numElemCode + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $nullSafeSetterCode + | $counter++; + | } + | } + | return $arrayData; + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] args) { + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) + } + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: 
String = s"concat(${children.map(_.sql).mkString(", ")})" +} + +/** + * Transforms an array of arrays into a single array. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """, + since = "2.4.0") +case class Flatten(child: Expression) extends UnaryExpression { + + private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def nullable: Boolean = child.nullable || childDataType.containsNull + + @transient override lazy val dataType: DataType = childDataType.elementType + + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.catalogString} type." + ) + } + + override def nullSafeEval(child: Any): Any = { + val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val arrayData = elements.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + s"$numberOfElements elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") + } + val flattenedData = new Array(numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, flattenedData, position, arr.length) + position += arr.length + } + new GenericArrayData(flattenedData) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) + } + ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } - override def prettyName: String = "array_contains" + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = s""" + |long $variableName = 0; + |for (int z = 0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $variableName + " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForFlattenOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + 
|${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | if (arr.isNullAt(l)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue("arr", elementType, "l")} + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForFlattenOfNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "flatten" +} + +@ExpressionDescription( + usage = """ + _FUNC_(start, stop, step) - Generates an array of elements from start to stop (inclusive), + incrementing by step. The type of the returned elements is the same as the type of argument + expressions. + + Supported types are: byte, short, integer, long, date, timestamp. + + The start and stop expressions must resolve to the same type. + If start and stop expressions resolve to the 'date' or 'timestamp' type + then the step expression must resolve to the 'interval' type, otherwise to the same type + as the start and stop expressions. + """, + arguments = """ + Arguments: + * start - an expression. The start of the range. + * stop - an expression. The end the range (inclusive). + * step - an optional expression. The step of the range. + By default step is 1 if start is less than or equal to stop, otherwise -1. + For the temporal sequences it's 1 day and -1 day respectively. + If start is greater than stop then the step must be negative, and vice versa. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(1, 5); + [1, 2, 3, 4, 5] + > SELECT _FUNC_(5, 1); + [5, 4, 3, 2, 1] + > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month); + [2018-01-01, 2018-02-01, 2018-03-01] + """, + since = "2.4.0" +) +case class Sequence( + start: Expression, + stop: Expression, + stepOpt: Option[Expression], + timeZoneId: Option[String] = None) + extends Expression + with TimeZoneAwareExpression { + + import Sequence._ + + def this(start: Expression, stop: Expression) = + this(start, stop, None, None) + + def this(start: Expression, stop: Expression, step: Expression) = + this(start, stop, Some(step), None) + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Some(timeZoneId)) + + override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + + override def checkInputDataTypes(): TypeCheckResult = { + val startType = start.dataType + def stepType = stepOpt.get.dataType + val typesCorrect = + startType.sameType(stop.dataType) && + (startType match { + case TimestampType | DateType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) + case _: IntegralType => + stepOpt.isEmpty || stepType.sameType(startType) + case _ => false + }) + + if (typesCorrect) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName only supports integral, timestamp or date types") + } + } + + def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType) + + def castChildrenTo(widerType: DataType): Expression = Sequence( + Cast(start, widerType), + Cast(stop, widerType), + stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), + timeZoneId) + + @transient private lazy val impl: SequenceImpl = dataType.elementType match { + case iType: IntegralType => + type T = iType.InternalType + val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) + new IntegralSequenceImpl(iType)(ct, iType.integral) + + case TimestampType => + new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone) + + case DateType => + new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone) + } + + override def eval(input: InternalRow): Any = { + val startVal = start.eval(input) + if (startVal == null) return null + val stopVal = stop.eval(input) + if (stopVal == null) return null + val stepVal = stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal)) + if (stepVal == null) return null + + ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val startGen = start.genCode(ctx) + val stopGen = stop.genCode(ctx) + val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse( + impl.defaultStep.genCode(ctx, startGen, stopGen)) + + val resultType = CodeGenerator.javaType(dataType) + val resultCode = { + val arr = ctx.freshName("arr") + val arrElemType = CodeGenerator.javaType(dataType.elementType) + s""" + |final $arrElemType[] $arr = null; + |${impl.genCode(ctx, startGen.value, stopGen.value, stepGen.value, arr, arrElemType)} + |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr); + """.stripMargin + } + + if (nullable) { + val nullSafeEval = + startGen.code + 
ctx.nullSafeExec(start.nullable, startGen.isNull) { + stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) { + stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), stepGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode + """.stripMargin + } + } + } + ev.copy(code = + code""" + |boolean ${ev.isNull} = true; + |$resultType ${ev.value} = null; + |$nullSafeEval + """.stripMargin) + + } else { + ev.copy(code = + code""" + |${startGen.code} + |${stopGen.code} + |${stepGen.code} + |$resultType ${ev.value} = null; + |$resultCode + """.stripMargin, + isNull = FalseLiteral) + } + } +} + +object Sequence { + + private type LessThanOrEqualFn = (Any, Any) => Boolean + + private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, one: Any) { + private val negativeOne = UnaryMinus(Literal(one)).eval() + + def apply(start: Any, stop: Any): Any = { + if (lteq(start, stop)) one else negativeOne + } + + def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: ExprCode): ExprCode = { + val Seq(oneVal, negativeOneVal) = Seq(one, negativeOne).map(Literal(_).genCode(ctx).value) + ExprCode.forNonNullValue(JavaCode.expression( + s"${startGen.value} <= ${stopGen.value} ? $oneVal : $negativeOneVal", + stepType)) + } + } + + private trait SequenceImpl { + def eval(start: Any, stop: Any, step: Any): Any + + def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String + + val defaultStep: DefaultStep + } + + private class IntegralSequenceImpl[T: ClassTag] + (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + elemType, + num.one) + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + import num._ + + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[T] + + var i: Int = getSequenceLength(start, stop, step) + val arr = new Array[T](i) + while (i > 0) { + i -= 1 + arr(i) = start + step * num.fromInt(i) + } + arr + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val i = ctx.freshName("i") + s""" + |${genSequenceLengthCode(ctx, start, stop, step, i)} + |$arr = new $elemType[$i]; + |while ($i > 0) { + | $i--; + | $arr[$i] = ($elemType) ($start + $step * $i); + |} + """.stripMargin + } + } + + private class TemporalSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone) + (implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + CalendarIntervalType, + new CalendarInterval(0, MICROS_PER_DAY)) + + private val backedSequenceImpl = new IntegralSequenceImpl[T](dt) + private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[CalendarInterval] + val stepMonths = step.months + val stepMicros = step.microseconds + + if (stepMonths == 0) { + backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale)) + + } else { + // To estimate the resulted array length we need to make assumptions + // about a month length in microseconds + val intervalStepInMicros = stepMicros 
+ stepMonths * microsPerMonth + val startMicros: Long = num.toLong(start) * scale + val stopMicros: Long = num.toLong(stop) * scale + val maxEstimatedArrayLength = + getSequenceLength(startMicros, stopMicros, intervalStepInMicros) + + val stepSign = if (stopMicros > startMicros) +1 else -1 + val exclusiveItem = stopMicros + stepSign + val arr = new Array[T](maxEstimatedArrayLength) + var t = startMicros + var i = 0 + + while (t < exclusiveItem ^ stepSign < 0) { + arr(i) = fromLong(t / scale) + t = timestampAddInterval(t, stepMonths, stepMicros, timeZone) + i += 1 + } + + // truncate array to the correct length + if (arr.length == i) arr else arr.slice(0, i) + } + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val stepMonths = ctx.freshName("stepMonths") + val stepMicros = ctx.freshName("stepMicros") + val stepScaled = ctx.freshName("stepScaled") + val intervalInMicros = ctx.freshName("intervalInMicros") + val startMicros = ctx.freshName("startMicros") + val stopMicros = ctx.freshName("stopMicros") + val arrLength = ctx.freshName("arrLength") + val stepSign = ctx.freshName("stepSign") + val exclusiveItem = ctx.freshName("exclusiveItem") + val t = ctx.freshName("t") + val i = ctx.freshName("i") + val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName) + + val sequenceLengthCode = + s""" + |final long $intervalInMicros = $stepMicros + $stepMonths * ${microsPerMonth}L; + |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} + """.stripMargin + + val timestampAddIntervalCode = + s""" + |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval( + | $t, $stepMonths, $stepMicros, $genTimeZone); + """.stripMargin + + s""" + |final int $stepMonths = $step.months; + |final long $stepMicros = $step.microseconds; + | + |if ($stepMonths == 0) { + | final $elemType $stepScaled = ($elemType) ($stepMicros / ${scale}L); + | ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, arr, elemType)}; + | + |} else { + | final long $startMicros = $start * ${scale}L; + | final long $stopMicros = $stop * ${scale}L; + | + | $sequenceLengthCode + | + | final int $stepSign = $stopMicros > $startMicros ? +1 : -1; + | final long $exclusiveItem = $stopMicros + $stepSign; + | + | $arr = new $elemType[$arrLength]; + | long $t = $startMicros; + | int $i = 0; + | + | while ($t < $exclusiveItem ^ $stepSign < 0) { + | $arr[$i] = ($elemType) ($t / ${scale}L); + | $timestampAddIntervalCode + | $i += 1; + | } + | + | if ($arr.length > $i) { + | $arr = java.util.Arrays.copyOf($arr, $i); + | } + |} + """.stripMargin + } + } + + private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = { + import num._ + require( + (step > num.zero && start <= stop) + || (step < num.zero && start >= stop) + || (step == num.zero && start == stop), + s"Illegal sequence boundaries: $start to $stop by $step") + + val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong + + require( + len <= MAX_ROUNDED_ARRAY_LENGTH, + s"Too long sequence: $len. 
Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + len.toInt + } + + private def genSequenceLengthCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + len: String): String = { + val longLen = ctx.freshName("longLen") + s""" + |if (!(($step > 0 && $start <= $stop) || + | ($step < 0 && $start >= $stop) || + | ($step == 0 && $start == $stop))) { + | throw new IllegalArgumentException( + | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); + |} + |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step; + |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { + | throw new IllegalArgumentException( + | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); + |} + |int $len = (int) $longLen; + """.stripMargin + } +} + +/** + * Returns the array containing the given input value (left) count (right) times. + */ +@ExpressionDescription( + usage = "_FUNC_(element, count) - Returns the array containing element count times.", + examples = """ + Examples: + > SELECT _FUNC_('123', 2); + ['123', '123'] + """, + since = "2.4.0") +case class ArrayRepeat(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + + override def nullable: Boolean = right.nullable + + override def eval(input: InternalRow): Any = { + val count = right.eval(input) + if (count == null) { + null + } else { + if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + + s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + } + val element = left.eval(input) + new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) + } + } + + override def prettyName: String = "array_repeat" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val element = leftGen.value + val count = rightGen.value + val et = dataType.elementType + + val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { + genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) + } else { + genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) + } + val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) + + ev.copy(code = + code""" + |boolean ${ev.isNull} = false; + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = + | ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin) + } + + private def nullElementsProtection( + ev: ExprCode, + rightIsNull: String, + coreLogic: String): String = { + if (nullable) { + s""" + |if ($rightIsNull) { + | ${ev.isNull} = true; + |} else { + | ${coreLogic} + |} + """.stripMargin + } else { + coreLogic + } + } + + private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = { + val numElements = ctx.freshName("numElements") + val numElementsCode = + s""" + |int $numElements = 0; + |if ($count > 0) { + | $numElements = $count; + |} + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + + | " elements due to exceeding the array size limit" + + | " 
${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + |} + """.stripMargin + + (numElements, numElementsCode) + } + + private def genCodeForPrimitiveElement( + ctx: CodegenContext, + elementType: DataType, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val tempArrayDataName = ctx.freshName("tempArrayData") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val errorMessage = s" $prettyName failed." + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |if (!$leftIsNull) { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | } + |} else { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.setNullAt(k); + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForNonPrimitiveElement( + ctx: CodegenContext, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |if (!$leftIsNull) { + | for (int k = 0; k < $numElemName; k++) { + | $arrayName[k] = $element; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + } /** - * Returns the minimum value in the array. + * Remove all elements that equal to element from the given array */ @ExpressionDescription( - usage = "_FUNC_(array) - Returns the minimum value in the array. 
NULL elements are skipped.", + usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.", examples = """ Examples: - > SELECT _FUNC_(array(1, 20, null, 3)); - 1 + > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); + [1,2,null] """, since = "2.4.0") -case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayRemove(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { - override def nullable: Boolean = true + override def dataType: DataType = left.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) override def checkInputDataTypes(): TypeCheckResult = { - val typeCheckResult = super.checkInputDataTypes() - if (typeCheckResult.isSuccess) { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") - } else { - typeCheckResult + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - val javaType = CodeGenerator.javaType(dataType) - val i = ctx.freshName("i") - val item = ExprCode("", - isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), - value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) - ev.copy(code = + override def nullSafeEval(arr: Any, value: Any): Any = { + val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) + var pos = 0 + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null || !ordering.equiv(v, value)) { + newArray(pos) = v + pos += 1 + } + ) + new GenericArrayData(newArray.slice(0, pos)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val numsToRemove = ctx.freshName("numsToRemove") + val newArraySize = ctx.freshName("newArraySize") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) s""" - |${childGen.code} - |boolean ${ev.isNull} = true; - |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${childGen.isNull}) { - | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { - | ${ctx.reassignIfSmaller(dataType, ev, item)} + |int $numsToRemove = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && $isEqual) { + | $numsToRemove = $numsToRemove + 1; | } |} - """.stripMargin) - } - - override protected def nullSafeEval(input: Any): Any = { - var min: Any = null - input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => - if (item != null && (min == null || ordering.lt(item, min))) { - min = item - } - ) - min + |int $newArraySize = $arr.numElements() - $numsToRemove; + |${genCodeForResult(ctx, ev, arr, value, newArraySize)} + """.stripMargin + }) } - override def dataType: 
DataType = child.dataType match { - case ArrayType(dt, _) => dt - case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + value: String, + newArraySize: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val getValue = CodeGenerator.getValue(inputArray, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |int $pos = 0; + |Object[] $values = new Object[$newArraySize]; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values[$pos] = null; + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values[$pos] = $getValue; + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values.set$primitiveValueTypeName($pos, $getValue); + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = $values; + """.stripMargin + } } - override def prettyName: String = "array_min" + override def prettyName: String = "array_remove" } /** - * Returns the maximum value in the array. + * Removes duplicate values from the array. */ @ExpressionDescription( - usage = "_FUNC_(array) - Returns the maximum value in the array. 
NULL elements are skipped.", + usage = "_FUNC_(array) - Removes duplicate values from the array.", examples = """ Examples: - > SELECT _FUNC_(array(1, 20, null, 3)); - 20 + > SELECT _FUNC_(array(1, 2, 3, null, 3)); + [1,2,3,null] """, since = "2.4.0") -case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def nullable: Boolean = true +case class ArrayDistinct(child: Expression) + extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + override def dataType: DataType = child.dataType + + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) override def checkInputDataTypes(): TypeCheckResult = { - val typeCheckResult = super.checkInputDataTypes() - if (typeCheckResult.isSuccess) { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") - } else { - typeCheckResult + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - val javaType = CodeGenerator.javaType(dataType) - val i = ctx.freshName("i") - val item = ExprCode("", - isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), - value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) - ev.copy(code = - s""" - |${childGen.code} - |boolean ${ev.isNull} = true; - |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${childGen.isNull}) { - | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { - | ${ctx.reassignIfGreater(dataType, ev, item)} - | } - |} - """.stripMargin) + @transient protected lazy val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false } - override protected def nullSafeEval(input: Any): Any = { - var max: Any = null - input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => - if (item != null && (max == null || ordering.gt(item, max))) { - max = item + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. 
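// A minimal standalone sketch of the semantics array_distinct implements above:
// keep the first occurrence of each element and at most one null, preserving
// input order. It uses only the Scala standard library (mutable.HashSet instead
// of the specialized OpenHashSet used by the generated code), so the object and
// method names below are assumptions for illustration, not Spark APIs.
object ArrayDistinctSemanticsSketch {
  def distinctPreservingOrder(input: Seq[Any]): Seq[Any] = {
    val seen = scala.collection.mutable.HashSet.empty[Any]
    var seenNull = false
    val out = scala.collection.mutable.ArrayBuffer.empty[Any]
    input.foreach {
      case null => if (!seenNull) { seenNull = true; out += null }
      case v => if (seen.add(v)) out += v // add returns false when v was already present
    }
    out.toList
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the documented example: array_distinct(array(1, 2, 3, null, 3)) -> [1,2,3,null]
    println(distinctPreservingOrder(Seq(1, 2, 3, null, 3)).mkString("[", ",", "]"))
  }
}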
+ @transient protected lazy val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" + } + + override def nullSafeEval(array: Any): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + doEvaluation(data) + } + + @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) { + (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + (data: Array[AnyRef]) => { + var foundNullElement = false + var pos = 0 + for (i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } + } else { + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j - 1) { + pos = pos + 1 + } + } } - ) - max + new GenericArrayData(data.slice(0, pos)) + } } - override def dataType: DataType = child.dataType match { - case ArrayType(dt, _) => dt - case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (canUseSpecializedHashSet) { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val getValue = CodeGenerator.getValue(array, elementType, i) + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add$hsPostFix($hsValueCast$getValue); + | } + |} + |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 
1 : 0); + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (array) => { + val expr = ctx.addReferenceObj("arrayDistinctExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array);" + }) + } } - override def prettyName: String = "array_max" -} + private def setNull( + foundNullElement: String, + distinctArray: String, + pos: String): String = { + val setNullValue = s"$distinctArray.setNullAt($pos)" + s""" + |if (!($foundNullElement)) { + | $setNullValue; + | $pos = $pos + 1; + | $foundNullElement = true; + |} + """.stripMargin + } + + private def setValue( + hs: String, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + s""" + |if (!($hs.contains$hsPostFix($hsValueCast$getValue1))) { + | $hs.add$hsPostFix($hsValueCast$getValue1); + | $distinctArray.set$primitiveValueTypeName($pos, $getValue1); + | $pos = $pos + 1; + |} + """.stripMargin + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + size: String): String = { + val distinctArray = ctx.freshName("distinctArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) + val foundNullElement = ctx.freshName("foundNullElement") + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | ${setNull(foundNullElement, distinctArray, pos)} + | } else { + | ${setValue(hs, distinctArray, pos, getValue1, primitiveValueTypeName)} + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin + } + override def prettyName: String = "array_distinct" +} /** - * Returns the position of the first occurrence of element in the given array as long. - * Returns 0 if the given value could not be found in the array. Returns null if either of - * the arguments are null - * - * NOTE: that this is not zero based, but 1-based index. The first element in the array has - * index 1. + * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ -@ExpressionDescription( - usage = """ - _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. 
- """, - examples = """ - Examples: - > SELECT _FUNC_(array(3, 2, 1), 1); - 3 - """, - since = "2.4.0") -case class ArrayPosition(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } - override def dataType: DataType = LongType - override def inputTypes: Seq[AbstractDataType] = - Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) - override def nullSafeEval(arr: Any, value: Any): Any = { - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) { - return (i + 1).toLong - } - ) - 0L + @transient protected lazy val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false } - override def prettyName: String = "array_position" + protected def genGetValue(array: String, i: String): String = + CodeGenerator.getValue(array, elementType, i) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (arr, value) => { - val pos = ctx.freshName("arrayPosition") - val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(arr, right.dataType, i) + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName (elementType) + elementType match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" + } + + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. 
+ @transient protected lazy val nullValueHolder = elementType match { + case ByteType => "(byte) 0" + case ShortType => "(short) 0" + case _ => "0" + } + + protected def withResultArrayNullCheck( + body: String, + value: String, + nullElementIndex: String): String = { + if (dataType.asInstanceOf[ArrayType].containsNull) { s""" - |int $pos = 0; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { - | $pos = $i + 1; - | break; - | } + |$body + |if ($nullElementIndex >= 0) { + | // result has null element + | $value.setNullAt($nullElementIndex); |} - |${ev.value} = (long) $pos; """.stripMargin - }) + } else { + body + } + } + + def buildResultArray( + builder: String, + value : String, + size : String, + nullElementIndex : String): String = withResultArrayNullCheck( + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Cannot create array with " + $size + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); + |} + | + |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { + | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | $value = new ${classOf[GenericArrayData].getName}($builder.result()); + |} + """.stripMargin, value, nullElementIndex) +} + +object ArraySetLike { + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") } } + /** - * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + * Returns an array of the elements in the union of x and y, without duplicates */ @ExpressionDescription( usage = """ - _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, - accesses elements from the last to the first. Returns NULL if the index exceeds the length - of the array. - - _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. 
""", examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 2); - 2 - > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); - "b" + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) """, since = "2.4.0") -case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { - - override def dataType: DataType = left.dataType match { - case ArrayType(elementType, _) => elementType - case MapType(_, valueType, _) => valueType - } - - override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(ArrayType, MapType), - left.dataType match { - case _: ArrayType => IntegerType - case _: MapType => left.dataType.asInstanceOf[MapType].keyType - } - ) - } - - override def nullable: Boolean = true +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { - override def nullSafeEval(value: Any, ordinal: Any): Any = { - left.dataType match { - case _: ArrayType => - val array = value.asInstanceOf[ArrayData] - val index = ordinal.asInstanceOf[Int] - if (array.numElements() < math.abs(index)) { - null - } else { - val idx = if (index == 0) { - throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") - } else if (index > 0) { - index - 1 - } else { - array.numElements() + index + @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { + if (TypeUtils.typeWithProperEquals(elementType)) { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 } - if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { - null + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } } else { - array.get(idx, dataType) + // check elem is already stored in arrayBuffer or not? 
+ var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } } - } - case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) } } + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalUnion(array1, array2) + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - left.dataType match { - case _: ArrayType => - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val index = ctx.freshName("elementAtIndex") - val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val array = ctx.freshName("array") + val arrays = ctx.freshName("arrays") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + def withArrayNullAssignment(body: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { s""" - |if ($eval1.isNullAt($index)) { - | ${ev.isNull} = true; - |} else + |if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = true; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} """.stripMargin } else { - "" + body } + + val processArray = withArrayNullAssignment( s""" - |int $index = (int) $eval2; - |if ($eval1.numElements() < Math.abs($index)) { - | ${ev.isNull} = true; - |} else { - | if ($index == 0) { - | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); - | } else if ($index > 0) { - | $index--; - | } else { - | $index += $eval1.numElements(); - | } - | $nullCheck - | { - | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + |$jt $value = ${genGetValue(array, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; | } + | $hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. 
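// A minimal standalone sketch of the union semantics the surrounding code
// produces: elements of both inputs in order of first occurrence, duplicates
// dropped, at most one null, and a guard on the result length. mutable.HashSet
// and the MaxResultLength constant are stand-ins chosen for illustration; the
// real paths use OpenHashSet, a primitive ArrayBuilder and
// ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
object ArrayUnionSemanticsSketch {
  private val MaxResultLength = Int.MaxValue - 15 // assumed to mirror the rounded-array limit

  def union(a: Seq[Any], b: Seq[Any]): Seq[Any] = {
    val seen = scala.collection.mutable.HashSet.empty[Any]
    var seenNull = false
    val out = scala.collection.mutable.ArrayBuffer.empty[Any]
    (a ++ b).foreach {
      case null => if (!seenNull) { seenNull = true; out += null }
      case v =>
        if (!seen.contains(v)) {
          if (out.size >= MaxResultLength) {
            throw new RuntimeException(s"Union result would exceed $MaxResultLength elements")
          }
          seen += v
          out += v
        }
    }
    out.toList
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the documented example: array_union(array(1, 2, 3), array(1, 3, 5)) -> [1,2,3,5]
    println(union(Seq(1, 2, 3), Seq(1, 3, 5)).mkString("[", ",", "]"))
  }
}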
+ val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; """.stripMargin - }) - case _: MapType => - doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } else { + "" + } + + s""" + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |int $size = 0; + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |ArrayData[] $arrays = new ArrayData[]{$array1, $array2}; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | ArrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | $processArray + | } + |} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayUnionExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" + }) } } - override def prettyName: String = "element_at" + override def prettyName: String = "array_union" +} + +object ArrayUnion { + def unionOrdering( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } } /** - * Concatenates multiple input columns together into a single column. - * The function works with strings, binary and compatible array columns. + * Returns an array of the elements in the intersect of x and y, without duplicates */ @ExpressionDescription( - usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and + array2, without duplicates. 
+ """, examples = """ Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); - | [1,2,3,4,5,6] - """) -case class Concat(children: Seq[Expression]) extends Expression { - - private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - - val allowedTypes = Seq(StringType, BinaryType, ArrayType) + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 3) + """, + since = "2.4.0") +case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + ArrayType(elementType, + left.dataType.asInstanceOf[ArrayType].containsNull && + right.dataType.asInstanceOf[ArrayType].containsNull) + } - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckSuccess + @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { + if (TypeUtils.typeWithProperEquals(elementType)) { + (array1, array2) => + if (array1.numElements() != 0 && array2.numElements() != 0) { + val hs = new OpenHashSet[Any] + val hsResult = new OpenHashSet[Any] + var foundNullElement = false + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + foundNullElement = true + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (foundNullElement) { + arrayBuffer += null + foundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (hs.contains(elem) && !hsResult.contains(elem)) { + arrayBuffer += elem + hsResult.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Array.emptyObjectArray) + } } else { - val childTypes = children.map(_.dataType) - if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { - return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + - s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) - } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + (array1, array2) => + if (array1.numElements() != 0 && array2.numElements() != 0) { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadySeenNull = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!alreadySeenNull) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + alreadySeenNull = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + if (ordering.equiv(elem1, elem2)) { + // check whether elem1 is already stored in arrayBuffer + var foundArrayBuffer = false + var k = 0 + while (!foundArrayBuffer && k < arrayBuffer.size) { + val va = arrayBuffer(k) + foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + found = !foundArrayBuffer + } + } + j += 1 + } + } + if (found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Array.emptyObjectArray) + } } } - override def dataType: DataType = 
children.map(_.dataType).headOption.getOrElse(StringType) + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] - lazy val javaType: String = CodeGenerator.javaType(dataType) + evalIntersect(array1, array2) + } - override def nullable: Boolean = children.exists(_.nullable) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) - override def foldable: Boolean = children.forall(_.foldable) + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val hashSetResult = ctx.freshName("hashSetResult") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" - override def eval(input: InternalRow): Any = dataType match { - case BinaryType => - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) - case StringType => - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - case ArrayType(elementType, _) => - val inputs = children.toStream.map(_.eval(input)) - if (inputs.contains(null)) { - null - } else { - val arrayData = inputs.map(_.asInstanceOf[ArrayData]) - val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { - throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") - } - val finalData = new Array[AnyRef](numberOfElements.toInt) - var position = 0 - for(ad <- arrayData) { - val arr = ad.toObjectArray(elementType) - Array.copy(arr, 0, finalData, position, arr.length) - position += arr.length - } - new GenericArrayData(finalData) - } - } + def withArray2NullCheck(body: String): String = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $foundNullElement = true; + |} else { + | $body + |} + """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. 
+ s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") + val writeArray2ToHashSet = withArray2NullCheck( + s""" + |$jt $value = ${genGetValue(array2, i)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ - } + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = false; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + s""" + |if (!$array1.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } - val (concatenator, initCode) = dataType match { - case BinaryType => - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - case StringType => - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - case ArrayType(elementType, _) => - val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrays(ctx, elementType) + val processArray1 = withArray1NullAssignment( + s""" + |$jt $value = ${genGetValue(array1, i)}; + |if ($hashSet.contains($hsValueCast$value) && + | !$hashSetResult.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSetResult.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. 
+ val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin } else { - genCodeForNonPrimitiveArrays(ctx, elementType) + "" } - (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + + s""" + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$openHashSet $hashSetResult = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $writeArray2ToHashSet + |} + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $processArray1 + |} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayIntersectExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" + }) } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(s""" - $initCode - $codes - $javaType ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) } - private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { - val numElements = ctx.freshName("numElements") - val code = s""" - |long $numElements = 0L; - |for (int z = 0; z < ${children.length}; z++) { - | $numElements += args[z].numElements(); - |} - |if ($numElements > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); - |} - """.stripMargin + override def prettyName: String = "array_intersect" +} - (code, numElements) +/** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2, + without duplicates. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(2) + """, + since = "2.4.0") +case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + + override def dataType: DataType = { + dataTypeCheck + left.dataType } - private def nullArgumentProtection() : String = { - if (nullable) { - s""" - |for (int z = 0; z < ${children.length}; z++) { - | if (args[z] == null) return null; - |} - """.stripMargin + @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { + if (TypeUtils.typeWithProperEquals(elementType)) { + (array1, array2) => + val hs = new OpenHashSet[Any] + var notFoundNullElement = true + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + notFoundNullElement = false + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (notFoundNullElement) { + arrayBuffer += null + notFoundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (!hs.contains(elem)) { + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) } else { - "" + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var scannedNullElements = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (elem1 == null) { + if (!scannedNullElements) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + scannedNullElements = true + } else { + found = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + val elem2 = array2.get(j, elementType) + if (elem2 != null) { + found = ordering.equiv(elem1, elem2) + } + j += 1 + } + if (!found) { + // check whether elem1 is already stored in arrayBuffer + var k = 0 + while (!found && k < arrayBuffer.size) { + val va = arrayBuffer(k) + found = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + } + } + if (!found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) } } - private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") - val counter = ctx.freshName("counter") - val arrayData = ctx.freshName("arrayData") + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + evalExcept(array1, array2) + } - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + - | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + - | " for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = 
ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | $unsafeArraySizeInBytes - | byte[] $arrayName = new byte[(int)$arraySizeName]; - | UnsafeArrayData $arrayData = new UnsafeArrayData(); - | Platform.putLong($arrayName, $baseOffset, $numElemName); - | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | if (args[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - | } else { - | $arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} - | ); - | } - | $counter++; - | } - | } - | return $arrayData; - | } - |}""".stripMargin.stripPrefix("\n") - } + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val notFoundNullElement = ctx.freshName("notFoundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" - private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayData = ctx.freshName("arrayObjects") - val counter = ctx.freshName("counter") + def withArray2NullCheck(body: String): String = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + |} else { + | $body + |} + """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. 
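// The same kind of sketch for array_except, mirroring evalExcept above and the specialized
// codegen being assembled here: keep the array1 elements that are absent from array2,
// de-duplicated, and keep at most one null only when array2 itself contains no null.
// Illustrative only; not a Spark API.
def exceptDistinct(a: Seq[Any], b: Seq[Any]): Seq[Any] = {
  val inB = b.filter(_ != null).toSet
  val bHasNull = b.contains(null)
  val seen = scala.collection.mutable.HashSet.empty[Any]
  val out = scala.collection.mutable.ArrayBuffer.empty[Any]
  var nullEmitted = false
  a.foreach {
    case null => if (!bHasNull && !nullEmitted) { out += null; nullEmitted = true }
    case e    => if (!inB(e) && seen.add(e)) out += e
  }
  out
}
// exceptDistinct(Seq(1, 2, 2, null, 3), Seq(1, 3)) == Seq(2, null)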
+ s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + val writeArray2ToHashSet = withArray2NullCheck( + s""" + |$jt $value = ${genGetValue(array2, i)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | Object[] $arrayData = new Object[(int)$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - | } - |}""".stripMargin.stripPrefix("\n") - } + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($notFoundNullElement) { + | $nullElementIndex = $size; + | $notFoundNullElement = false; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } - override def toString: String = s"concat(${children.mkString(", ")})" + val processArray1 = withArray1NullAssignment( + s""" + |$jt $value = ${genGetValue(array1, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) - override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" + // Only need to track null element index when array1's element is nullable. + val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $notFoundNullElement = true; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + + s""" + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $writeArray2ToHashSet + |} + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $processArray1 + |} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayExceptExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" + }) + } + } + + override def prettyName: String = "array_except" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67876a8565488..077a6dc93bd17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import 
org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -47,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(StringType), + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType)) + .getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -63,7 +65,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + assigns + postprocess, + code = code"${preprocess}${assigns}${postprocess}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -178,14 +180,14 @@ case class CreateMap(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure( s"$prettyName expects a positive even number of arguments.") - } else if (keys.map(_.dataType).distinct.length > 1) { + } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given keys of function map should all be the same type, but they are " + - keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (values.map(_.dataType).distinct.length > 1) { + keys.map(_.dataType.catalogString).mkString("[", ", ", "]")) + } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given values of function map should all be the same type, but they are " + - values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + values.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { TypeCheckResult.TypeCheckSuccess } @@ -193,8 +195,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def dataType: DataType = { MapType( - keyType = keys.headOption.map(_.dataType).getOrElse(StringType), - valueType = values.headOption.map(_.dataType).getOrElse(StringType), + keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) + .getOrElse(StringType), + valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType)) + .getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } @@ -219,7 +223,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = - s""" + code""" final boolean ${ev.isNull} = false; $preprocessKeyData $assignKeys @@ -235,6 +239,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def prettyName: String = "map" } +/** + * Returns a catalyst Map containing the two arrays in children expressions as keys and values. + */ +@ExpressionDescription( + usage = """ + _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. 
All elements + in keys should not be null""", + examples = """ + Examples: + > SELECT _FUNC_([1.0, 3.0], ['2', '4']); + {1.0:"2",3.0:"4"} + """, since = "2.4.0") +case class MapFromArrays(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def dataType: DataType = { + MapType( + keyType = left.dataType.asInstanceOf[ArrayType].elementType, + valueType = right.dataType.asInstanceOf[ArrayType].elementType, + valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) + } + + override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { + val keyArrayData = keyArray.asInstanceOf[ArrayData] + val valueArrayData = valueArray.asInstanceOf[ArrayData] + if (keyArrayData.numElements != valueArrayData.numElements) { + throw new RuntimeException("The given two arrays should have the same length") + } + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + if (leftArrayType.containsNull) { + var i = 0 + while (i < keyArrayData.numElements) { + if (keyArrayData.isNullAt(i)) { + throw new RuntimeException("Cannot use null as map key!") + } + i += 1 + } + } + new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { + val arrayBasedMapData = classOf[ArrayBasedMapData].getName + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { + val i = ctx.freshName("i") + s""" + |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { + | if ($keyArrayData.isNullAt($i)) { + | throw new RuntimeException("Cannot use null as map key!"); + | } + |} + """.stripMargin + } + s""" + |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { + | throw new RuntimeException("The given two arrays should have the same length"); + |} + |$keyArrayElemNullCheck + |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); + """.stripMargin + }) + } + + override def prettyName: String = "map_from_arrays" +} + /** * An expression representing a not yet available attribute name. 
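// A plain-Scala sketch of the two checks map_from_arrays performs in nullSafeEval above:
// the arrays must have the same length and the key array must not contain null. The Scala
// Map here only illustrates the resulting key/value pairing, not the catalyst MapData that
// is actually produced.
def mapFromArrays(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
  require(keys.length == values.length, "The given two arrays should have the same length")
  require(!keys.contains(null), "Cannot use null as map key!")
  keys.zip(values).toMap
}
// mapFromArrays(Seq(1.0, 3.0), Seq("2", "4")) == Map(1.0 -> "2", 3.0 -> "4")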
This expression is unevaluable * and as its name suggests it is a temporary place holder until we're able to determine the @@ -314,8 +388,8 @@ trait CreateNamedStructLike extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" + + s" position, got: ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { @@ -373,7 +447,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc extraArguments = "Object[]" -> values :: Nil) ev.copy(code = - s""" + code""" |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3fba52d745453..8994eeff92c7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -68,7 +68,7 @@ object ExtractValue { case StructType(_) => s"Field name should be String Literal, but it's $extraction" case other => - s"Can't extract value from $child: need struct type but got ${other.simpleString}" + s"Can't extract value from $child: need struct type but got ${other.catalogString}" } throw new AnalysisException(errorMsg) } @@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. 
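// The change just below switches GetMapValue's key lookup from `==` to an interpreted
// ordering. The difference matters for key types whose Java equals is reference-based,
// such as byte arrays (BinaryType). Standalone illustration; the ordering here is only a
// stand-in for TypeUtils.getInterpretedOrdering:
val k1 = Array[Byte](1, 2, 3)
val k2 = Array[Byte](1, 2, 3)
val byValue: Ordering[Array[Byte]] = Ordering.by[Array[Byte], String](_.mkString(","))
// k1 == k2              => false: array `==` is reference equality and misses the match
// byValue.equiv(k1, k2) => true:  element-wise (value) comparison finds it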
- def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy var i = 0 var found = false while (i < length && !found) { - if (keys.get(i, keyType) == ordinal) { + if (ordering.equiv(keys.get(i, keyType), ordinal)) { found = true } else { i += 1 @@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil with ExtractValue with NullIntolerant { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + private def keyType = child.dataType.asInstanceOf[MapType].keyType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName") + } + } + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) @@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression) // todo: current search is O(n), improve it. override def nullSafeEval(value: Any, ordinal: Any): Any = { - getValueEval(value, ordinal, keyType) + getValueEval(value, ordinal, keyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 205d77f6a9acf..bed581a61b2dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -32,7 +33,12 @@ import org.apache.spark.sql.types._ """) // scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends ComplexTypeMergingExpression { + + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + Seq(trueValue.dataType, falseValue.dataType) + } override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable @@ -41,17 +47,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + - s"not ${predicate.dataType.simpleString}") - } else if (!trueValue.dataType.sameType(falseValue.dataType)) { + s"not 
${predicate.dataType.catalogString}") + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + - s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") + s"(${trueValue.dataType.catalogString} and ${falseValue.dataType.catalogString}).") } else { TypeCheckResult.TypeCheckSuccess } } - override def dataType: DataType = trueValue.dataType - override def eval(input: InternalRow): Any = { if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) @@ -66,7 +70,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.genCode(ctx) val code = - s""" + code""" |${condEval.code} |boolean ${ev.isNull} = false; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -117,27 +121,24 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen( branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression with Serializable { + extends ComplexTypeMergingExpression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType) - - def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { - case Seq(dt1, dt2) => dt1.sameType(dt2) + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + branches.map(_._2.dataType) ++ elseValue.map(_.dataType) } - override def dataType: DataType = branches.head._2.dataType - override def nullable: Boolean = { // Result is nullable if any of the branch is nullable, or if the else value is nullable branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } override def checkInputDataTypes(): TypeCheckResult = { - // Make sure all branch conditions are boolean types. - if (valueTypesEqual) { + if (TypeCoercion.haveSameType(inputTypesForMerging)) { + // Make sure all branch conditions are boolean types. 
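// Why If and CaseWhen now merge their branch types instead of requiring exact equality:
// branch types that differ only in their nullability flags should still be accepted, with
// the merged result taking the more permissive flag. A small illustration with catalyst
// types (assuming only that the two branches produce these array types):
import org.apache.spark.sql.types._
val t1 = ArrayType(IntegerType, containsNull = false)
val t2 = ArrayType(IntegerType, containsNull = true)
// t1 == t2        => false: strict equality distinguishes the containsNull flags
// t1.sameType(t2) => true:  equal up to nullability, so the branches are compatible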
if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess } else { @@ -265,7 +266,7 @@ case class CaseWhen( }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes @@ -293,7 +294,7 @@ object CaseWhen { case cond :: value :: Nil => Some((cond, value)) case value :: Nil => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } @@ -308,7 +309,7 @@ object CaseKeyWhen { case Seq(cond, value) => Some((EqualTo(key, cond), value)) case Seq(value) => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala new file mode 100644 index 0000000000000..2917b0b8c9c53 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} +import org.apache.spark.sql.types.DataType + +case class KnownNotNull(child: Expression) extends UnaryExpression { + override def nullable: Boolean = false + override def dataType: DataType = child.dataType + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx).copy(isNull = FalseLiteral) + } + + override def eval(input: InternalRow): Any = { + child.eval(input) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index b9b2cd5bdb9f0..f95798d64db19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -717,7 +718,7 @@ abstract class UnixTime } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -746,7 +747,7 @@ abstract class UnixTime }) case TimestampType => val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -757,7 +758,7 @@ abstract class UnixTime val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1016,6 +1017,48 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } } +/** + * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp, + * which requires the timestamp string to not have timezone information, otherwise null is returned. 
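// A java.time analogue of the rule just described: accept a bare timestamp string, but yield
// null (here: None) as soon as the string carries its own zone or offset. The helper name and
// pattern are illustrative; the real parsing goes through DateTimeUtils.stringToTimestamp.
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import scala.util.Try
def parseWithoutZone(s: String): Option[LocalDateTime] =
  Try(LocalDateTime.parse(s, DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"))).toOption
// parseWithoutZone("2018-03-13 06:18:23")       => Some(2018-03-13T06:18:23)
// parseWithoutZone("2018-03-13 06:18:23+00:00") => None (zone information present, rejected)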
+ */ +case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def dataType: DataType = TimestampType + override def nullable: Boolean = true + override def toString: String = child.toString + override def sql: String = child.sql + + override def nullSafeEval(input: Any): Any = { + DateTimeUtils.stringToTimestamp( + input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = ctx.addReferenceObj("timeZone", timeZone) + val longOpt = ctx.freshName("longOpt") + val eval = child.genCode(ctx) + val code = code""" + |${eval.code} + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; + |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; + |if (!${eval.isNull}) { + | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); + | if ($longOpt.isDefined()) { + | ${ev.value} = ((Long) $longOpt.get()).longValue(); + | ${ev.isNull} = false; + | } + |} + """.stripMargin + ev.copy(code = code) + } +} + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield @@ -1048,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1062,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1152,42 +1195,61 @@ case class AddMonths(startDate: Expression, numMonths: Expression) } /** - * Returns number of months between dates date1 and date2. + * Returns number of months between times `timestamp1` and `timestamp2`. + * If `timestamp1` is later than `timestamp2`, then the result is positive. + * If `timestamp1` and `timestamp2` are on the same day of month, or both + * are the last day of month, time of day will be ignored. Otherwise, the + * difference is calculated based on 31 days per month, and rounded to + * 8 digits unless roundOff=false. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timestamp1, timestamp2) - Returns number of months between `timestamp1` and `timestamp2`.", + usage = """ + _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result + is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both + are the last day of month, time of day will be ignored. Otherwise, the difference is + calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false. 
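// Working the documented example through the rules above (day-of-month and time-of-day
// difference measured against a 31-day month); the rounding call is a stand-in for the
// exact rounding used internally:
// months_between('1997-02-28 10:30:00', '1996-10-30')
val wholeMonths = (1997 * 12 + 2) - (1996 * 12 + 10)   // 4 calendar months apart
val secondsDiff = (28 - 30) * 86400.0 + 10.5 * 3600.0  // day-of-month plus time-of-day delta
val raw = wholeMonths + secondsDiff / (31 * 86400.0)   // ~= 3.9495967741935485 (roundOff = false)
val rounded = math.round(raw * 1e8) / 1e8              // 3.94959677 (default: rounded to 8 digits)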
+ """, examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); 3.94959677 + > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30', false); + 3.9495967741935485 """, since = "1.5.0") // scalastyle:on line.size.limit -case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { +case class MonthsBetween( + date1: Expression, + date2: Expression, + roundOff: Expression, + timeZoneId: Option[String] = None) + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) - def this(date1: Expression, date2: Expression) = this(date1, date2, None) + def this(date1: Expression, date2: Expression, roundOff: Expression) = + this(date1, date2, roundOff, None) - override def left: Expression = date1 - override def right: Expression = date2 + override def children: Seq[Expression] = Seq(date1, date2, roundOff) - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType) override def dataType: DataType = DoubleType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone) + override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { + DateTimeUtils.monthsBetween( + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r, $tz)""" + defineCodeGen(ctx, ev, (d1, d2, roundOff) => { + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" }) } @@ -1226,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1240,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1283,7 +1345,7 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr } def this(left: Expression) = { - // backwards compatability + // backwards compatibility this(left, None, Cast(left, DateType)) } @@ -1383,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; 
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1471,14 +1533,14 @@ case class TruncDate(date: Expression, format: Expression) """, examples = """ Examples: - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR'); - 2015-01-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM'); - 2015-03-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD'); - 2015-03-05T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR'); - 2015-03-05T09:00:00 + > SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359'); + 2015-01-01 00:00:00 + > SELECT _FUNC_('MM', '2015-03-05T09:32:05.359'); + 2015-03-01 00:00:00 + > SELECT _FUNC_('DD', '2015-03-05T09:32:05.359'); + 2015-03-05 00:00:00 + > SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359'); + 2015-03-05 09:00:00 """, since = "2.3.0") // scalastyle:on line.size.limit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index db1579ba28671..04de83343be71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.types._ /** @@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple pass-through for code generation. 
*/ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + ev.copy(EmptyBlock) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfebad45e..d6e67b9ac3d10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -155,8 +156,8 @@ case class Stack(children: Seq[Expression]) extends Generator { val j = (i - 1) % numFields if (children(i).dataType != elementSchema.fields(j).dataType) { return TypeCheckResult.TypeCheckFailure( - s"Argument ${j + 1} (${elementSchema.fields(j).dataType.simpleString}) != " + - s"Argument $i (${children(i).dataType.simpleString})") + s"Argument ${j + 1} (${elementSchema.fields(j).dataType.catalogString}) != " + + s"Argument $i (${children(i).dataType.catalogString})") } } TypeCheckResult.TypeCheckSuccess @@ -215,13 +216,39 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName ev.copy(code = - s""" + code""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) } } +/** + * Replicate the row N times. N is specified as the first argument to the function. + * This is an internal function solely used by optimizer to rewrite EXCEPT ALL AND + * INTERSECT ALL queries. + */ +case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback { + private lazy val numColumns = children.length - 1 // remove the multiplier value from output. + + override def elementSchema: StructType = + StructType(children.tail.zipWithIndex.map { + case (e, index) => StructField(s"col$index", e.dataType) + }) + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val numRows = children.head.eval(input).asInstanceOf[Long] + val values = children.tail.map(_.eval(input)).toArray + Range.Long(0, numRows, 1).map { _ => + val fields = new Array[Any](numColumns) + for (col <- 0 until numColumns) { + fields.update(col, values(col)) + } + InternalRow(fields: _*) + } + } +} + /** * Wrapper around another generator to specify outer behavior. This is used to implement functions * such as explode_outer. This expression gets replaced during analysis. 
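// How a rewrite can use ReplicateRows: compute a multiplicity per distinct row, then emit that
// many copies of it. The multiplicity formula below (EXCEPT ALL keeps max(countLeft - countRight, 0)
// copies of each distinct row) illustrates the idea; it is not quoted from the optimizer rule itself.
def replicate(n: Long, row: Seq[Any]): Seq[Seq[Any]] = Seq.fill(n.toInt)(row)
val countsLeft = Map("a" -> 3L, "b" -> 1L)
val countsRight = Map("a" -> 1L, "b" -> 2L)
val exceptAll = countsLeft.toSeq.flatMap { case (row, c) =>
  replicate(math.max(c - countsRight.getOrElse(row, 0L), 0L), Seq(row))
}
// exceptAll == Seq(Seq("a"), Seq("a"))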
@@ -250,7 +277,7 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with case _ => TypeCheckResult.TypeCheckFailure( "input to function explode should be array or map type, " + - s"not ${child.dataType.simpleString}") + s"not ${child.dataType.catalogString}") } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -380,7 +407,7 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene case _ => TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should be array of struct type, " + - s"not ${child.dataType.simpleString}") + s"not ${child.dataType.catalogString}") } override def elementSchema: StructType = child.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ef790338bdd27..a754e87a17968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |$hashResultType ${ev.value} = $seed; |$codes """.stripMargin) @@ -403,14 +404,15 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val fieldsHash = fields.zipWithIndex.map { case (field, index) => - nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx) } val hashResultType = CodeGenerator.javaType(dataType) - ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, hashResultType -> result), + arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result), returnType = hashResultType, makeSplitFunction = body => s""" @@ -418,6 +420,10 @@ abstract class HashExpression[E] extends Expression { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |$code + """.stripMargin } @tailrec @@ -674,7 +680,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes @@ -777,10 +783,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val childResult = ctx.freshName("childResult") val fieldsHash = fields.zipWithIndex.map { case (field, index) => val computeFieldHash = nullSafeElementHash( - input, index.toString, field.nullable, 
field.dataType, childResult, ctx) + tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx) s""" |$childResult = 0; |$computeFieldHash @@ -788,10 +795,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result), returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" @@ -800,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |${CodeGenerator.JAVA_INT} $childResult = 0; + |$code + """.stripMargin } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala new file mode 100644 index 0000000000000..2bb6b20b944d4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -0,0 +1,846 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.concurrent.atomic.AtomicReference + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods + +/** + * A named lambda variable. 
+ */ +case class NamedLambdaVariable( + name: String, + dataType: DataType, + nullable: Boolean, + exprId: ExprId = NamedExpression.newExprId, + value: AtomicReference[Any] = new AtomicReference()) + extends LeafExpression + with NamedExpression + with CodegenFallback { + + override def qualifier: Seq[String] = Seq.empty + + override def newInstance(): NamedExpression = + copy(exprId = NamedExpression.newExprId, value = new AtomicReference()) + + override def toAttribute: Attribute = { + AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty) + } + + override def eval(input: InternalRow): Any = value.get + + override def toString: String = s"lambda $name#${exprId.id}$typeSuffix" + + override def simpleString: String = s"lambda $name#${exprId.id}: ${dataType.simpleString}" +} + +/** + * A lambda function and its arguments. A lambda function can be hidden when a user wants to + * process an completely independent expression in a [[HigherOrderFunction]], the lambda function + * and its variables are then only used for internal bookkeeping within the higher order function. + */ +case class LambdaFunction( + function: Expression, + arguments: Seq[NamedExpression], + hidden: Boolean = false) + extends Expression with CodegenFallback { + + override def children: Seq[Expression] = function +: arguments + override def dataType: DataType = function.dataType + override def nullable: Boolean = function.nullable + + lazy val bound: Boolean = arguments.forall(_.resolved) + + override def eval(input: InternalRow): Any = function.eval(input) +} + +object LambdaFunction { + val identity: LambdaFunction = { + val id = UnresolvedAttribute.quoted("id") + LambdaFunction(id, Seq(id)) + } +} + +/** + * A higher order function takes one or more (lambda) functions and applies these to some objects. + * The function produces a number of variables which can be consumed by some lambda function. + */ +trait HigherOrderFunction extends Expression with ExpectsInputTypes { + + override def nullable: Boolean = arguments.exists(_.nullable) + + override def children: Seq[Expression] = arguments ++ functions + + /** + * Arguments of the higher ordered function. + */ + def arguments: Seq[Expression] + + def argumentTypes: Seq[AbstractDataType] + + /** + * All arguments have been resolved. This means that the types and nullabilty of (most of) the + * lambda function arguments is known, and that we can start binding the lambda functions. + */ + lazy val argumentsResolved: Boolean = arguments.forall(_.resolved) + + /** + * Checks the argument data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `argumentsResolved == true`. + */ + def checkArgumentDataTypes(): TypeCheckResult = { + ExpectsInputTypes.checkInputDataTypes(arguments, argumentTypes) + } + + /** + * Functions applied by the higher order function. + */ + def functions: Seq[Expression] + + def functionTypes: Seq[AbstractDataType] + + override def inputTypes: Seq[AbstractDataType] = argumentTypes ++ functionTypes + + /** + * All inputs must be resolved and all functions must be resolved lambda functions. + */ + override lazy val resolved: Boolean = argumentsResolved && functions.forall { + case l: LambdaFunction => l.resolved + case _ => false + } + + /** + * Bind the lambda functions to the [[HigherOrderFunction]] using the given bind function. 
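// A stripped-down model of the evaluation contract above: the enclosing higher-order function
// writes the current element into the lambda variable's shared AtomicReference slot and then
// evaluates the lambda body, which reads the slot back. Names are illustrative only.
import java.util.concurrent.atomic.AtomicReference
val slot = new AtomicReference[Any]()
val body: () => Any = () => slot.get().asInstanceOf[Int] + 1  // stands in for the bound lambda body
val transformed = Seq(1, 2, 3).map { e => slot.set(e); body() }
// transformed == Seq(2, 3, 4)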
The + * bind function takes the potential lambda and it's (partial) arguments and converts this into + * a bound lambda function. + */ + def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction + + // Make sure the lambda variables refer the same instances as of arguments for case that the + // variables in instantiated separately during serialization or for some reason. + @transient lazy val functionsForEval: Seq[Expression] = functions.map { + case LambdaFunction(function, arguments, hidden) => + val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap + function.transformUp { + case variable: NamedLambdaVariable if argumentMap.contains(variable.exprId) => + argumentMap(variable.exprId) + } + } +} + +/** + * Trait for functions having as input one argument and one function. + */ +trait SimpleHigherOrderFunction extends HigherOrderFunction { + + def argument: Expression + + override def arguments: Seq[Expression] = argument :: Nil + + def argumentType: AbstractDataType + + override def argumentTypes(): Seq[AbstractDataType] = argumentType :: Nil + + def function: Expression + + override def functions: Seq[Expression] = function :: Nil + + def functionType: AbstractDataType = AnyDataType + + override def functionTypes: Seq[AbstractDataType] = functionType :: Nil + + def functionForEval: Expression = functionsForEval.head + + /** + * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method + * in order to save null-check code. + */ + protected def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = + sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval") + + override def eval(inputRow: InternalRow): Any = { + val value = argument.eval(inputRow) + if (value == null) { + null + } else { + nullSafeEval(inputRow, value) + } + } +} + +trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { + override def argumentType: AbstractDataType = ArrayType +} + +trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { + override def argumentType: AbstractDataType = MapType +} + +/** + * Transform elements in an array using the transform function. This is similar to + * a `map` in functional programming. 
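// The collection analogy for the transform function introduced here: the unary form is an
// ordinary map over the elements, and the binary form also receives the element index.
val xs = Seq(1, 2, 3)
val unary = xs.map(_ + 1)                                 // like transform(array(1,2,3), x -> x + 1)      => Seq(2, 3, 4)
val binary = xs.zipWithIndex.map { case (x, i) => x + i } // like transform(array(1,2,3), (x, i) -> x + i) => Seq(1, 3, 5)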
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in an array using the function.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x + 1); + array(2, 3, 4) + > SELECT _FUNC_(array(1, 2, 3), (x, i) -> x + i); + array(1, 3, 5) + """, + since = "2.4.0") +case class ArrayTransform( + argument: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { + val ArrayType(elementType, containsNull) = argument.dataType + function match { + case LambdaFunction(_, arguments, _) if arguments.size == 2 => + copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil)) + case _ => + copy(function = f(function, (elementType, containsNull) :: Nil)) + } + } + + @transient lazy val (elementVar, indexVar) = { + val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function + val indexVar = if (tail.nonEmpty) { + Some(tail.head.asInstanceOf[NamedLambdaVariable]) + } else { + None + } + (elementVar, indexVar) + } + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] + val f = functionForEval + val result = new GenericArrayData(new Array[Any](arr.numElements)) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (indexVar.isDefined) { + indexVar.get.value.set(i) + } + result.update(i, f.eval(inputRow)) + i += 1 + } + result + } + + override def prettyName: String = "transform" +} + +/** + * Filters entries in a map using the provided function. + */ +@ExpressionDescription( +usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", +examples = """ + Examples: + > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); + [1 -> 0, 3 -> -1] + """, +since = "2.4.0") +case class MapFilter( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + @transient lazy val (keyVar, valueVar) = { + val args = function.asInstanceOf[LambdaFunction].arguments + (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) + } + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val m = argumentValue.asInstanceOf[MapData] + val f = functionForEval + val retKeys = new mutable.ListBuffer[Any] + val retValues = new mutable.ListBuffer[Any] + m.foreach(keyType, valueType, (k, v) => { + keyVar.value.set(k) + valueVar.value.set(v) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + retKeys += k + retValues += v + } + }) + ArrayBasedMapData(retKeys.toArray, retValues.toArray) + } + + override def dataType: DataType = argument.dataType + + override def functionType: AbstractDataType = BooleanType + + override def prettyName: String = "map_filter" +} + +/** + * Filters the input array using the given lambda function. 
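// Collection analogies for the two filter variants in this file: map_filter keeps the entries
// whose (key, value) pair satisfies the predicate, and filter (defined next) keeps the array
// elements that do.
val keptEntries = Map(1 -> 0, 2 -> 2, 3 -> -1).filter { case (k, v) => k > v } // Map(1 -> 0, 3 -> -1)
val keptElements = Seq(1, 2, 3).filter(_ % 2 == 1)                             // Seq(1, 3)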
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Filters the input array using the given predicate.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1); + array(1, 3) + """, + since = "2.4.0") +case class ArrayFilter( + argument: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def dataType: DataType = argument.dataType + + override def functionType: AbstractDataType = BooleanType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) + } + + @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] + val f = functionForEval + val buffer = new mutable.ArrayBuffer[Any](arr.numElements) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + buffer += elementVar.value.get + } + i += 1 + } + new GenericArrayData(buffer) + } + + override def prettyName: String = "filter" +} + +/** + * Tests whether a predicate holds for one or more elements in the array. + */ +@ExpressionDescription(usage = + "_FUNC_(expr, pred) - Tests whether a predicate holds for one or more elements in the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0); + true + """, + since = "2.4.0") +case class ArrayExists( + argument: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def dataType: DataType = BooleanType + + override def functionType: AbstractDataType = BooleanType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) + } + + @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] + val f = functionForEval + var exists = false + var i = 0 + while (i < arr.numElements && !exists) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + exists = true + } + i += 1 + } + exists + } + + override def prettyName: String = "exists" +} + +/** + * Applies a binary operator to a start value and all elements in the array. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(expr, start, merge, finish) - Applies a binary operator to an initial state and all + elements in the array, and reduces this to a single state. The final state is converted + into the final result by applying a finish function. 
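// Collection analogies: exists is a short-circuiting predicate test over the elements, and
// aggregate (defined next) is a foldLeft over the array followed by an optional finish step.
val hasEven = Seq(1, 2, 3).exists(_ % 2 == 0)  // true
val summed = Seq(1, 2, 3).foldLeft(0)(_ + _)   // 6, like aggregate(array(1,2,3), 0, (acc, x) -> acc + x)
val finish = (acc: Int) => acc * 10
val finished = finish(summed)                  // 60, with the finish function acc -> acc * 10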
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x); + 6 + > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10); + 60 + """, + since = "2.4.0") +case class ArrayAggregate( + argument: Expression, + zero: Expression, + merge: Expression, + finish: Expression) + extends HigherOrderFunction with CodegenFallback { + + def this(argument: Expression, zero: Expression, merge: Expression) = { + this(argument, zero, merge, LambdaFunction.identity) + } + + override def arguments: Seq[Expression] = argument :: zero :: Nil + + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: AnyDataType :: Nil + + override def functions: Seq[Expression] = merge :: finish :: Nil + + override def functionTypes: Seq[AbstractDataType] = zero.dataType :: AnyDataType :: Nil + + override def nullable: Boolean = argument.nullable || finish.nullable + + override def dataType: DataType = finish.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (!DataType.equalsStructurally( + zero.dataType, merge.dataType, ignoreNullability = true)) { + TypeCheckResult.TypeCheckFailure( + s"argument 3 requires ${zero.dataType.simpleString} type, " + + s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure + } + } + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { + // Be very conservative with nullable. We cannot be sure that the accumulator does not + // evaluate to null. So we always set nullable to true here. + val ArrayType(elementType, containsNull) = argument.dataType + val acc = zero.dataType -> true + val newMerge = f(merge, acc :: (elementType, containsNull) :: Nil) + val newFinish = f(finish, acc :: Nil) + copy(merge = newMerge, finish = newFinish) + } + + @transient lazy val LambdaFunction(_, + Seq(accForMergeVar: NamedLambdaVariable, elementVar: NamedLambdaVariable), _) = merge + @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish + + override def eval(input: InternalRow): Any = { + val arr = argument.eval(input).asInstanceOf[ArrayData] + if (arr == null) { + null + } else { + val Seq(mergeForEval, finishForEval) = functionsForEval + accForMergeVar.value.set(zero.eval(input)) + var i = 0 + while (i < arr.numElements()) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + accForMergeVar.value.set(mergeForEval.eval(input)) + i += 1 + } + accForFinishVar.value.set(accForMergeVar.value.get) + finishForEval.eval(input) + } + } + + override def prettyName: String = "aggregate" +} + +/** + * Transform Keys for every entry of the map by applying the transform_keys function. 
+ * Returns map with transformed key entries + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); + map(array(2, 3, 4), array(1, 2, 3)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + map(array(2, 4, 6), array(1, 2, 3)) + """, + since = "2.4.0") +case class TransformKeys( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] + val resultKeys = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + val result = functionForEval.eval(inputRow) + if (result == null) { + throw new RuntimeException("Cannot use null as map key!") + } + resultKeys.update(i, result) + i += 1 + } + new ArrayBasedMapData(resultKeys, map.valueArray()) + } + + override def prettyName: String = "transform_keys" +} + +/** + * Returns a map that applies the function to each value of the map. 
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); + map(array(1, 2, 3), array(2, 3, 4)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + map(array(1, 2, 3), array(2, 4, 6)) + """, + since = "2.4.0") +case class TransformValues( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction) + : TransformValues = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] + val resultValues = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + resultValues.update(i, functionForEval.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(map.keyArray(), resultValues) + } + + override def prettyName: String = "transform_values" +} + +/** + * Merges two given maps into a single map by applying function to the pair of values with + * the same key. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(map1, map2, function) - Merges two given maps into a single map by applying + function to the pair of values with the same key. For keys only presented in one map, + NULL will be passed as the value for the missing key. If an input map contains duplicated + keys, only the first entry of the duplicated key is passed into the lambda function. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); + {1:"ax",2:"by"} + """, + since = "2.4.0") +case class MapZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + def functionForEval: Expression = functionsForEval.head + + @transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType + + @transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType + + @transient lazy val keyType = + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get + + @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil + + override def functions: Seq[Expression] = function :: Nil + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { + val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) + copy(function = f(function, arguments)) + } + + override def checkArgumentDataTypes(): TypeCheckResult = { + super.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (leftKeyType.sameType(rightKeyType)) { + TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName") + } else { + TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + + s"been two ${MapType.simpleString}s with compatible key types, but the key types are " + + s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].") + } + case failure => failure + } + } + + override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes() + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + nullSafeEval(input, value1, value2) + } + } + } + + @transient lazy val LambdaFunction(_, Seq( + keyVar: NamedLambdaVariable, + value1Var: NamedLambdaVariable, + value2Var: NamedLambdaVariable), + _) = function + + /** + * The function accepts two key arrays and returns a collection of keys with indexes + * to value arrays. Indexes are represented as an array of two items. This is a small + * optimization leveraging mutability of arrays. 
+ */ + @transient private lazy val getKeysWithValueIndexes: + (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = { + if (TypeUtils.typeWithProperEquals(keyType)) { + getKeysWithIndexesFast + } else { + getKeysWithIndexesBruteForce + } + } + + private def assertSizeOfArrayBuffer(size: Int): Unit = { + if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to zip maps with $size " + + s"unique keys due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + } + + private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = { + val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]] + for((z, array) <- Array((0, keys1), (1, keys2))) { + var i = 0 + while (i < array.numElements()) { + val key = array.get(i, keyType) + hashMap.get(key) match { + case Some(indexes) => + if (indexes(z).isEmpty) { + indexes(z) = Some(i) + } + case None => + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + hashMap.put(key, indexes) + } + i += 1 + } + } + hashMap + } + + private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = { + val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] + for((z, array) <- Array((0, keys1), (1, keys2))) { + var i = 0 + while (i < array.numElements()) { + val key = array.get(i, keyType) + var found = false + var j = 0 + while (!found && j < arrayBuffer.size) { + val (bufferKey, indexes) = arrayBuffer(j) + if (ordering.equiv(bufferKey, key)) { + found = true + if(indexes(z).isEmpty) { + indexes(z) = Some(i) + } + } + j += 1 + } + if (!found) { + assertSizeOfArrayBuffer(arrayBuffer.size) + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + arrayBuffer += Tuple2(key, indexes) + } + i += 1 + } + } + arrayBuffer + } + + private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { + val mapData1 = value1.asInstanceOf[MapData] + val mapData2 = value2.asInstanceOf[MapData] + val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray()) + val size = keysWithIndexes.size + val keys = new GenericArrayData(new Array[Any](size)) + val values = new GenericArrayData(new Array[Any](size)) + val valueData1 = mapData1.valueArray() + val valueData2 = mapData2.valueArray() + var i = 0 + for ((key, Array(index1, index2)) <- keysWithIndexes) { + val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null) + val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null) + keyVar.value.set(key) + value1Var.value.set(v1) + value2Var.value.set(v2) + keys.update(i, key) + values.update(i, functionForEval.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(keys, values) + } + + override def prettyName: String = "map_zip_with" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. 
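[editor's note] The key-matching step above amounts to taking the union of both key sets in first-occurrence order (a LinkedHashMap in the fast path) and remembering, for each key, where its value sits on each side, with a missing side treated as null. A simplified plain-Scala restatement, using Option in place of the index arrays:

    import scala.collection.mutable

    def zipKeysSketch[K, V1, V2](m1: Map[K, V1], m2: Map[K, V2]): Seq[(K, Option[V1], Option[V2])] = {
      val keys = mutable.LinkedHashSet.empty[K]
      keys ++= m1.keys
      keys ++= m2.keys
      keys.toSeq.map(k => (k, m1.get(k), m2.get(k)))
    }

    // map_zip_with(map(1 -> "a", 2 -> "b"), map(1 -> "x", 3 -> "z"), ...) sees
    // (1, Some("a"), Some("x")), (2, Some("b"), None), (3, None, Some("z")).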
If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); + array(('a', 1), ('b', 2), ('c', 3)) + > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y)); + array(4, 6) + > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); + array('ad', 'be', 'cf') + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + def functionForEval: Expression = functionsForEval.head + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil + + override def functions: Seq[Expression] = List(function) + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ZipWith = { + val ArrayType(leftElementType, _) = left.dataType + val ArrayType(rightElementType, _) = right.dataType + copy(function = f(function, + (leftElementType, true) :: (rightElementType, true) :: Nil)) + } + + @transient lazy val LambdaFunction(_, + Seq(leftElemVar: NamedLambdaVariable, rightElemVar: NamedLambdaVariable), _) = function + + override def eval(input: InternalRow): Any = { + val leftArr = left.eval(input).asInstanceOf[ArrayData] + if (leftArr == null) { + null + } else { + val rightArr = right.eval(input).asInstanceOf[ArrayData] + if (rightArr == null) { + null + } else { + val resultLength = math.max(leftArr.numElements(), rightArr.numElements()) + val f = functionForEval + val result = new GenericArrayData(new Array[Any](resultLength)) + var i = 0 + while (i < resultLength) { + if (i < leftArr.numElements()) { + leftElemVar.value.set(leftArr.get(i, leftElemVar.dataType)) + } else { + leftElemVar.value.set(null) + } + if (i < rightArr.numElements()) { + rightElemVar.value.set(rightArr.get(i, rightElemVar.dataType)) + } else { + rightElemVar.value.set(null) + } + result.update(i, f.eval(input)) + i += 1 + } + result + } + } + } + + override def prettyName: String = "zip_with" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 2a3cc580273ee..3b0141ad52cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final 
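[editor's note] zip_with evaluates the lambda max(left.length, right.length) times, substituting null for whichever side has run out. A sketch of that padding with Option standing in for SQL NULL (the helper name is hypothetical):

    def zipWithSketch[A, B, C](xs: Seq[A], ys: Seq[B])(f: (Option[A], Option[B]) => C): Seq[C] = {
      val n = math.max(xs.length, ys.length)
      (0 until n).map(i => f(xs.lift(i), ys.lift(i)))
    }

    assert(zipWithSketch(Seq(1, 2), Seq(3, 4, 5)) { (x, y) => (x, y) } ==
      Seq((Some(1), Some(3)), (Some(2), Some(4)), (None, Some(5))))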
${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + isNull = FalseLiteral) } } @@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index fdd672c416a03..11cc88735a9a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} +import java.io._ import scala.util.parsing.combinator.RegexParsers @@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -495,7 +495,7 @@ case class JsonTuple(children: Seq[Expression]) } /** - * Converts an json input string to a [[StructType]] or [[ArrayType]] of [[StructType]]s + * Converts an json input string to a [[StructType]], [[ArrayType]] or [[MapType]] * with the specified schema. 
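[editor's note] The inputFileBlock changes above (and many hunks below) replace plain s-interpolated strings with the code interpolator imported from codegen.Block. The toy interpolator below only illustrates the general mechanism of a custom interpolator returning a structured value instead of a flat String; it is not Spark's Block API and makes no claim about its behavior.

    final case class CodeBlockSketch(parts: Seq[String], args: Seq[Any]) {
      override def toString: String =
        parts.zipAll(args.map(_.toString), "", "").map { case (p, a) => p + a }.mkString
    }
    implicit class CodeSketchInterpolator(sc: StringContext) {
      def code(args: Any*): CodeBlockSketch = CodeBlockSketch(sc.parts, args)
    }

    val value = "value_0"
    assert(code"final long $value = 1L;".toString == "final long value_0 = 1L;")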
*/ // scalastyle:off line.size.limit @@ -527,31 +527,27 @@ case class JsonToStructs( override def nullable: Boolean = true // Used in `FunctionRegistry` - def this(child: Expression, schema: Expression) = + def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), - options = Map.empty[String, String], + schema = JsonExprUtils.evalSchemaExpr(schema), + options = options, child = child, timeZoneId = None) + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) + def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), + schema = JsonExprUtils.evalSchemaExpr(schema), options = JsonExprUtils.convertToMapData(options), child = child, timeZoneId = None) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { - case _: StructType | ArrayType(_: StructType, _) => + case _: StructType | _: ArrayType | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") - } - - @transient - lazy val rowSchema = nullableSchema match { - case st: StructType => st - case ArrayType(st: StructType, _) => st + s"Input schema ${nullableSchema.catalogString} must be a struct, an array or a map.") } // This converts parsed rows to the desired output by the given schema. @@ -559,14 +555,16 @@ case class JsonToStructs( lazy val converter = nullableSchema match { case _: StructType => (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null - case ArrayType(_: StructType, _) => - (rows: Seq[InternalRow]) => new GenericArrayData(rows) + case _: ArrayType => + (rows: Seq[InternalRow]) => rows.head.getArray(0) + case _: MapType => + (rows: Seq[InternalRow]) => rows.head.getMap(0) } @transient lazy val parser = new JacksonParser( - rowSchema, + nullableSchema, new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) override def dataType: DataType = nullableSchema @@ -607,6 +605,11 @@ case class JsonToStructs( } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def sql: String = schema match { + case _: MapType => "entries" + case _ => super.sql + } } /** @@ -622,7 +625,7 @@ case class JsonToStructs( {"a":1,"b":2} > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} - > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); + > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2))); [{"a":1,"b":2}] > SELECT _FUNC_(map('a', named_struct('b', 1))); {"a":{"b":1}} @@ -719,7 +722,7 @@ case class StructsToJson( TypeCheckResult.TypeCheckFailure(e.getMessage) } case _ => TypeCheckResult.TypeCheckFailure( - s"Input type ${child.dataType.simpleString} must be a struct, array of structs or " + + s"Input type ${child.dataType.catalogString} must be a struct, array of structs or " + "a map or array of map.") } @@ -731,11 +734,44 @@ case class StructsToJson( override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil } +/** + * A function infers schema of JSON string. 
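[editor's note] With the JsonToStructs changes above, from_json accepts map and array schemas in addition to structs, and the schema argument can be produced by the new schema_of_json expression. A minimal usage sketch, assuming a Spark build that includes these changes (2.4.0+) and a local SparkSession:

    import org.apache.spark.sql.SparkSession

    object FromJsonMapExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("from_json-map").getOrCreate()
        // Parse a JSON object directly into a map, no wrapping struct required.
        spark.sql("""SELECT from_json('{"a":1,"b":2}', 'map<string,int>')""").show(false)
        // Infer a DDL schema string from a sample document.
        spark.sql("""SELECT schema_of_json('[{"col":0}]')""").show(false)
        spark.stop()
      }
    }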
+ */ +@ExpressionDescription( + usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.", + examples = """ + Examples: + > SELECT _FUNC_('[{"col":0}]'); + array> + """, + since = "2.4.0") +case class SchemaOfJson(child: Expression) + extends UnaryExpression with String2StringExpression with CodegenFallback { + + private val jsonOptions = new JSONOptions(Map.empty, "UTC") + private val jsonFactory = new JsonFactory() + jsonOptions.setJacksonOptions(jsonFactory) + + override def convert(v: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser => + parser.nextToken() + inferField(parser, jsonOptions) + } + + UTF8String.fromString(dt.catalogString) + } +} + object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): StructType = exp match { - case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) - case e => throw new AnalysisException(s"Expected a string literal instead of $e") + def evalSchemaExpr(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) + case e @ SchemaOfJson(_: Literal) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal" + + s" or output of the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { @@ -747,7 +783,7 @@ object JsonExprUtils { } case m: CreateMap => throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType}") + s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") case _ => throw new AnalysisException("Must use a map() function for options") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 246025b82d59e..0efd1224f1bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -57,6 +57,7 @@ object Literal { case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) + case c: Char => Literal(UTF8String.fromString(c.toString), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => @@ -185,7 +186,7 @@ object Literal { case map: MapType => create(Map(), map) case struct: StructType => create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct) - case udt: UserDefinedType[_] => default(udt.sqlType) + case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt) case other => throw new RuntimeException(s"no default for type $dataType") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bc4cfcec47425..c2e1720259b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -24,6 +24,7 @@ import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression, val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7eda65a867028..0cdeda9b10516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = s"""${eval.code} + ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, @@ -117,17 +118,21 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.", + usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. 
The value is returned as a canonical UUID 36-character string.""", examples = """ Examples: > SELECT _FUNC_(); 46707d92-02f4-4817-8116-a4c3b23e6266 - """) + """, + note = "The function is non-deterministic.") // scalastyle:on line.size.limit -case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful + with ExpressionWithRandomSeed { def this() = this(None) + override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false @@ -150,7 +155,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") - ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8df870468c2ad..584a2946bd564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -40,7 +40,16 @@ object NamedExpression { * * The `id` field is unique within a given JVM, while the `uuid` is used to uniquely identify JVMs. */ -case class ExprId(id: Long, jvmId: UUID) +case class ExprId(id: Long, jvmId: UUID) { + + override def equals(other: Any): Boolean = other match { + case ExprId(id, jvmId) => this.id == id && this.jvmId == jvmId + case _ => false + } + + override def hashCode(): Int = id.hashCode() + +} object ExprId { def apply(id: Long): ExprId = ExprId(id, NamedExpression.jvmId) @@ -62,19 +71,22 @@ trait NamedExpression extends Expression { * multiple qualifiers, it is possible that there are other possible way to refer to this * attribute. */ - def qualifiedName: String = (qualifier.toSeq :+ name).mkString(".") + def qualifiedName: String = (qualifier :+ name).mkString(".") /** * Optional qualifier for the expression. + * Qualifier can also contain the fully qualified information, for e.g, Sequence of string + * containing the database and the table name * * For now, since we do not allow using original table name to qualify a column name once the * table is aliased, this can only be: * * 1. Empty Seq: when an attribute doesn't have a qualifier, * e.g. top level attributes aliased in the SELECT clause, or column from a LocalRelation. - * 2. Single element: either the table name or the alias name of the table. + * 2. Seq with a Single element: either the table name or the alias name of the table. + * 3. 
Seq with 2 elements: database name and table name */ - def qualifier: Option[String] + def qualifier: Seq[String] def toAttribute: Attribute @@ -100,7 +112,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute - def withQualifier(newQualifier: Option[String]): Attribute + def withQualifier(newQualifier: Seq[String]): Attribute def withName(newName: String): Attribute def withMetadata(newMetadata: Metadata): Attribute @@ -121,14 +133,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * @param name The name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. - * @param qualifier An optional string that can be used to referred to this attribute in a fully - * qualified way. Consider the examples tableName.name, subQueryAlias.name. - * tableName and subQueryAlias are possible qualifiers. + * @param qualifier An optional Seq of string that can be used to refer to this attribute in a + * fully qualified way. Consider the examples tableName.name, subQueryAlias.name. + * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None, + val qualifier: Seq[String] = Seq.empty, val explicitMetadata: Option[Metadata] = None) extends UnaryExpression with NamedExpression { @@ -192,7 +204,7 @@ case class Alias(child: Expression, name: String)( } override def sql: String = { - val qualifierPrefix = qualifier.map(_ + ".").getOrElse("") + val qualifierPrefix = if (qualifier.nonEmpty) qualifier.mkString(".") + "." else "" s"${child.sql} AS $qualifierPrefix${quoteIdentifier(name)}" } } @@ -216,9 +228,11 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None) + val qualifier: Seq[String] = Seq.empty[String]) extends Attribute with Unevaluable { + // currently can only handle qualifier of length 2 + require(qualifier.length <= 2) /** * Returns true iff the expression id is the same for both attributes. */ @@ -277,7 +291,7 @@ case class AttributeReference( /** * Returns a copy of this [[AttributeReference]] with new qualifier. */ - override def withQualifier(newQualifier: Option[String]): AttributeReference = { + override def withQualifier(newQualifier: Seq[String]): AttributeReference = { if (newQualifier == qualifier) { this } else { @@ -315,7 +329,7 @@ case class AttributeReference( override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" override def sql: String = { - val qualifierPrefix = qualifier.map(_ + ".").getOrElse("") + val qualifierPrefix = if (qualifier.nonEmpty) qualifier.mkString(".") + "." 
else "" s"$qualifierPrefix${quoteIdentifier(name)}" } } @@ -341,12 +355,12 @@ case class PrettyAttribute( override def withNullability(newNullability: Boolean): Attribute = throw new UnsupportedOperationException override def newInstance(): Attribute = throw new UnsupportedOperationException - override def withQualifier(newQualifier: Option[String]): Attribute = + override def withQualifier(newQualifier: Seq[String]): Attribute = throw new UnsupportedOperationException override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def withMetadata(newMetadata: Metadata): Attribute = throw new UnsupportedOperationException - override def qualifier: Option[String] = throw new UnsupportedOperationException + override def qualifier: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = true } @@ -362,7 +376,7 @@ case class OuterReference(e: NamedExpression) override def prettyName: String = "outer" override def name: String = e.name - override def qualifier: Option[String] = e.qualifier + override def qualifier: Seq[String] = e.qualifier override def exprId: ExprId = e.exprId override def toAttribute: Attribute = e.toAttribute override def newInstance(): NamedExpression = OuterReference(e.newInstance()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0787342bce6bc..b683d2a7e9ef3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -43,7 +44,7 @@ import org.apache.spark.sql.types._ 1 """) // scalastyle:on line.size.limit -case class Coalesce(children: Seq[Expression]) extends Expression { +case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. 
*/ override def nullable: Boolean = children.forall(_.nullable) @@ -60,8 +61,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator @@ -111,7 +110,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { @@ -232,7 +231,7 @@ case class IsNaN(child: Expression) extends UnaryExpression val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) @@ -278,7 +277,7 @@ case class NaNvl(left: Expression, right: Expression) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -440,7 +439,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bc17d1229420a..3189e6841a525 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,8 +33,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -268,7 +270,7 @@ case class StaticInvoke( s"${ev.value} = $callFunc;" } - val code = s""" + val code = code""" $argCode $prepareIsNull $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -384,8 +386,7 @@ case class Invoke( """ } - val code = s""" - ${obj.code} + val code = obj.code + code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { @@ -491,7 +492,7 @@ case class NewInstance( s"new $className($argString)" } - val code = s""" + val code = code""" $argCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? @@ -531,9 +532,7 @@ case class UnwrapOption( val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? 
${CodeGenerator.defaultValue(dataType)} : (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); @@ -563,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); @@ -934,8 +931,7 @@ case class MapObjects private( ) } - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1040,11 +1036,13 @@ case class CatalystToExternalMap private( private lazy val valueConverter = CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) - private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + private lazy val (newMapBuilderMethod, moduleField) = { val clazz = Utils.classForName(collClass.getCanonicalName + "$") - val module = clazz.getField("MODULE$").get(null) - val method = clazz.getMethod("newBuilder") - method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) + } + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] } override def eval(input: InternalRow): Any = { @@ -1144,8 +1142,7 @@ case class CatalystToExternalMap private( """ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1252,8 +1249,72 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { + val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + def rowWrapper(data: Any): InternalRow = { + rowBuffer.update(0, data) + rowBuffer + } + + child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 + } + (keys, values) + } + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new 
RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 + } + (keys, values) + } + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result != null) { + val (keys, values) = mapCatalystConverter(result) + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } else { + null + } + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputMap = child.genCode(ctx) @@ -1324,9 +1385,8 @@ case class ExternalMapToCatalyst private( val mapCls = classOf[ArrayBasedMapData].getName val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) - val code = - s""" - ${inputMap.code} + val code = inputMap.code + + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); @@ -1404,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = - s""" + code""" |Object[] $values = new Object[${children.size}]; |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); @@ -1432,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ @@ -1465,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ @@ -1547,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp funcName = "initializeJavaBean", extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = - s""" - |${instanceGen.code} + val code = instanceGen.code + + code""" |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; |if (!${instanceGen.isNull}) { | $initializeCode @@ -1597,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // because errMsgField is used only when the value is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - val code = s""" - ${childGen.code} - + val code = childGen.code + code""" if (${childGen.isNull}) { throw new NullPointerException($errMsgField); } @@ -1642,7 +1697,7 @@ case class GetExternalRowField( // because errMsgField is used only when the field is null. 
val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) - val code = s""" + val code = code""" ${row.code} if (${row.isNull}) { @@ -1670,12 +1725,35 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def nullable: Boolean = child.nullable - override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) + + private val errMsg = s" is not a valid external type for schema of ${expected.catalogString}" - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val checkType: (Any) => Boolean = expected match { + case _: DecimalType => + (value: Any) => { + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] + } + case _: ArrayType => + (value: Any) => { + value.getClass.isArray || value.isInstanceOf[Seq[_]] + } + case _ => + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) + (value: Any) => { + dataTypeClazz.isInstance(value) + } + } - private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields @@ -1689,12 +1767,12 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" case _ => s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } - val code = s""" + val code = code""" ${input.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1a48995358af7..11dcc3ebf798c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalyst +import java.util.Locale + import com.google.common.collect.Maps +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -138,6 +142,123 @@ package object expressions { def indexOf(exprId: ExprId): Int = { Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } + + private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { + m.mapValues(_.distinct).map(identity) + } + + /** Map to use for direct case insensitive attribute lookups. 
*/ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { + unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) + } + + /** Map to use for qualified case insensitive attribute lookups with 2 part key */ + @transient private lazy val qualified: Map[(String, String), Seq[Attribute]] = { + // key is 2 part: table/alias and name + val grouped = attrs.filter(_.qualifier.nonEmpty).groupBy { + a => (a.qualifier.last.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Map to use for qualified case insensitive attribute lookups with 3 part key */ + @transient private val qualified3Part: Map[(String, String, String), Seq[Attribute]] = { + // key is 3 part: database name, table name and name + val grouped = attrs.filter(_.qualifier.length == 2).groupBy { a => + (a.qualifier.head.toLowerCase(Locale.ROOT), + a.qualifier.last.toLowerCase(Locale.ROOT), + a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + // Collect matching attributes given a name and a lookup. + def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(a.name, name) => a.withName(name) + }) + } + + // Find matches for the given name assuming that the 1st two parts are qualifier + // (i.e. database name and table name) and the 3rd part is the actual column name. + // + // For example, consider an example where "db1" is the database name, "a" is the table name + // and "b" is the column name and "c" is the struct field name. + // If the name parts is db1.a.b.c, then Attribute will match + // Attribute(b, qualifier("db1,"a")) and List("c") will be the second element + var matches: (Seq[Attribute], Seq[String]) = nameParts match { + case dbPart +: tblPart +: name +: nestedFields => + val key = (dbPart.toLowerCase(Locale.ROOT), + tblPart.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified3Part.get(key)).filter { + a => (resolver(dbPart, a.qualifier.head) && resolver(tblPart, a.qualifier.last)) + } + (attributes, nestedFields) + case all => + (Seq.empty, Seq.empty) + } + + // If there are no matches, then find matches for the given name assuming that + // the 1st part is a qualifier (i.e. table name, alias, or subquery alias) and the + // 2nd part is the actual name. This returns a tuple of + // matched attributes and a list of parts that are to be resolved. + // + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + if (matches._1.isEmpty) { + matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.last) + } + (attributes, nestedFields) + case all => + (Seq.empty[Attribute], Seq.empty[String]) + } + } + + // If none of attributes match database.table.column pattern or + // `table.column` pattern, we try to resolve it as a column. 
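[editor's note] The resolution order implemented above tries the longest qualifier first: database.table.column, then table.column, then a bare column name, and whatever is left over becomes nested struct fields. A much-simplified sketch of just that precedence (illustrative only; the real code also applies the resolver and case handling):

    def resolvePartsSketch(parts: Seq[String], isKnownPrefix: Seq[String] => Boolean)
        : Option[(Seq[String], Seq[String])] =
      Seq(3, 2, 1).filter(_ <= parts.length).collectFirst {
        case n if isKnownPrefix(parts.take(n)) => (parts.take(n), parts.drop(n))
      }

    val known = Set(Seq("db1", "a", "b"))
    // "db1.a.b.c" resolves the attribute db1.a.b with nested field List("c").
    assert(resolvePartsSketch(Seq("db1", "a", "b", "c"), known) ==
      Some((Seq("db1", "a", "b"), Seq("c"))))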
+ val (candidates, nestedFields) = matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + + def name = UnresolvedAttribute(nameParts).name + candidates match { + case Seq(a) if nestedFields.nonEmpty => + // One match, but we also need to extract the requested nested field. + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and aliased it with the last part of the name. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) => + ExtractValue(e, Literal(name), resolver) + } + Some(Alias(fieldExprs, nestedFields.last)()) + + case Seq(a) => + // One match, no nested fields, use it. + Some(a) + + case Seq() => + // No matches. + None + + case ambiguousReferences => + // More than one match. + val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ") + throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e195ec17f3bcf..149bd79278a54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -36,6 +37,14 @@ object InterpretedPredicate { case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] + + override def initialize(partitionIndex: Int): Unit = { + super.initialize(partitionIndex) + expression.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + } + } } /** @@ -129,6 +138,66 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } +/** + * Evaluates to `true` if `values` are returned in `query`'s result set. 
+ */ +case class InSubquery(values: Seq[Expression], query: ListQuery) + extends Predicate with Unevaluable { + + @transient lazy val value: Expression = if (values.length > 1) { + CreateNamedStruct(values.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + values.head + } + + + override def checkInputDataTypes(): TypeCheckResult = { + val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType, + ignoreNullability = true) + if (mismatchOpt) { + if (values.length != query.childOutputs.length) { + TypeCheckResult.TypeCheckFailure( + s""" + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${values.length}. + |#columns in right hand side: ${query.childOutputs.length}. + |Left side columns: + |[${values.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) + } else { + val mismatchedColumns = values.zip(query.childOutputs).flatMap { + case (l, r) if l.dataType != r.dataType => + Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") + case _ => None + } + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${values.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) + } + } else { + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + } + } + + override def children: Seq[Expression] = values :+ query + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ($query)" + override def sql: String = s"(${value.sql} IN (${query.sql}))" +} + /** * Evaluates to `true` if `list` contains `value`. @@ -160,44 +229,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, ignoreNullability = true)) if (mismatchOpt.isDefined) { - list match { - case ListQuery(_, _, _, childOutputs) :: Nil => - val valExprs = value match { - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - if (valExprs.length != childOutputs.length) { - TypeCheckResult.TypeCheckFailure( - s""" - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${valExprs.length}. - |#columns in right hand side: ${childOutputs.length}. - |Left side columns: - |[${valExprs.map(_.sql).mkString(", ")}]. 
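[Editorial sketch] At the SQL level, the new `InSubquery` node backs multi-column IN subqueries: the left-hand values are packed into a named struct and compared structurally against the subquery output, which is what `checkInputDataTypes` above verifies. The snippet below is a hedged example assuming an existing SparkSession `spark` and hypothetical tables `t1(a INT, b STRING)` and `t2(x INT, y STRING)`.

// Multi-column IN subquery: (a, b) must match the subquery output (x, y) in arity and type.
spark.sql("SELECT * FROM t1 WHERE (a, b) IN (SELECT x, y FROM t2)")

// A mismatched arity, e.g. (a, b) against a single output column, fails analysis with the
// "number of columns ... does not match" message produced by checkInputDataTypes above.
// spark.sql("SELECT * FROM t1 WHERE (a, b) IN (SELECT x FROM t2)")   // AnalysisException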
- |Right side columns: - |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) - } else { - val mismatchedColumns = valExprs.zip(childOutputs).flatMap { - case (l, r) if l.dataType != r.dataType => - s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" - case _ => None - } - TypeCheckResult.TypeCheckFailure( - s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) - } - case _ => - TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + - s"${value.dataType.simpleString} != ${mismatchOpt.get.dataType.simpleString}") - } + TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + + s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}") } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } @@ -282,7 +315,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { }.mkString("\n")) ev.copy(code = - s""" + code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { @@ -298,9 +331,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def sql: String = { - val childrenSQL = children.map(_.sql) - val valueSQL = childrenSQL.head - val listSQL = childrenSQL.tail.mkString(", ") + val valueSQL = value.sql + val listSQL = list.map(_.sql).mkString(", ") s"($valueSQL IN ($listSQL))" } } @@ -346,7 +378,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with "" } ev.copy(code = - s""" + code""" |${childGen.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; @@ -398,7 +430,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = false; @@ -407,7 +439,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -462,7 +494,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. 
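[Editorial sketch] The codegen comments for `And`/`Or` above rely on SQL's three-valued logic: AND is false whenever either side is false, and OR is true whenever either side is true, regardless of nulls on the other side. A minimal sketch of those truth rules using `Option[Boolean]` (with `None` standing in for NULL), not Spark code:

// SQL three-valued logic for AND/OR, mirroring the comments in the generated code above.
def sqlAnd(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] = (l, r) match {
  case (Some(false), _) | (_, Some(false)) => Some(false)
  case (Some(true), Some(true))            => Some(true)
  case _                                   => None
}

def sqlOr(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] = (l, r) match {
  case (Some(true), _) | (_, Some(true)) => Some(true)
  case (Some(false), Some(false))        => Some(false)
  case _                                 => None
}

assert(sqlAnd(Some(false), None) == Some(false))   // false AND NULL = false
assert(sqlOr(None, Some(true)) == Some(true))      // NULL OR true = true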
if (!left.nullable && !right.nullable) { ev.isNull = FalseLiteral - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = true; @@ -471,7 +503,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -613,7 +645,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.copy(code = eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 70186053617f8..b70c34141b97d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -56,6 +57,14 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType)) } +/** + * Represents the behavior of expressions which have a random seed and can renew the seed. + * Usually the random seed needs to be renewed at each execution under streaming queries. + */ +trait ExpressionWithRandomSeed { + def withNewSeed(seed: Long): Expression +} + /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
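[Editorial sketch] `Rand`/`Randn` seed each partition's generator with `seed + partitionIndex`, and the new `ExpressionWithRandomSeed` trait lets a streaming query swap in a fresh seed per run via `withNewSeed`. A simplified model of that behavior, using `java.util.Random` instead of Spark's XORShiftRandom and a hypothetical `SimpleRand` class:

import java.util.Random

// A fixed seed combined with the partition index at initialization time yields a
// deterministic stream per (seed, partition); withNewSeed produces a copy that can be
// re-seeded between streaming micro-batches.
case class SimpleRand(seed: Long) {
  @transient private var rng: Random = _
  def initialize(partitionIndex: Int): Unit = rng = new Random(seed + partitionIndex)
  def next(): Double = rng.nextDouble()
  def withNewSeed(newSeed: Long): SimpleRand = SimpleRand(newSeed)
}

val r = SimpleRand(42L)
r.initialize(partitionIndex = 0)
val first = r.next()   // identical across runs for seed 42 and partition 0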
*/ // scalastyle:off line.size.limit @ExpressionDescription( @@ -68,12 +77,15 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful 0.8446490682263027 > SELECT _FUNC_(null); 0.8446490682263027 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit -case class Rand(child: Expression) extends RDG { +case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed { def this() = this(Literal(Utils.random.nextLong(), LongType)) + override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType)) + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -81,7 +93,7 @@ case class Rand(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -96,7 +108,7 @@ object Rand { /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.", + usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""", examples = """ Examples: > SELECT _FUNC_(); @@ -105,12 +117,15 @@ object Rand { 1.1164209726833079 > SELECT _FUNC_(null); 1.1164209726833079 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit -case class Randn(child: Expression) extends RDG { +case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed { def this() = this(Literal(Utils.random.nextLong(), LongType)) + override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType)) + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -118,7 +133,7 @@ case class Randn(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index ad0c0791d895f..bf0c35fe61018 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -123,7 
+124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -271,7 +272,7 @@ case class StringSplit(str: Expression, pattern: Expression) usage = "_FUNC_(str, regexp, rep) - Replaces all substrings of `str` that match `regexp` with `rep`.", examples = """ Examples: - > SELECT _FUNC_('100-200', '(\d+)', 'num'); + > SELECT _FUNC_('100-200', '(\\d+)', 'num'); num-num """) // scalastyle:on line.size.limit @@ -370,7 +371,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches `regexp`.", examples = """ Examples: - > SELECT _FUNC_('100-200', '(\d+)-(\d+)', 1); + > SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1); 100 """) case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ea005a26a4c8b..14faa62bde7d0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -90,7 +91,7 @@ case class ConcatWs(children: Seq[Expression]) val args = ctx.freshName("args") val inputs = strings.zipWithIndex.map { case (eval, index) => - if (eval.isNull != "true") { + if (eval.isNull != TrueLiteral) { s""" ${eval.code} if (!${eval.isNull}) { @@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" UTF8String[] $args = new 
UTF8String[$numArgs]; ${separator.code} $codes @@ -122,14 +123,14 @@ case class ConcatWs(children: Seq[Expression]) child.dataType match { case StringType => ("", // we count all the StringType arguments num at once below. - if (eval.isNull == "true") { + if (eval.isNull == TrueLiteral) { "" } else { s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") - if (eval.isNull == "true") { + if (eval.isNull == TrueLiteral) { ("", "") } else { (s""" @@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression]) foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( - s""" + code""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxVararg = 0; @@ -221,12 +222,13 @@ case class Elt(children: Seq[Expression]) extends Expression { val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) if (indexType != IntegerType) { return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + - s"have IntegerType, but it's $indexType") + s"have ${IntegerType.catalogString}, but it's ${indexType.catalogString}") } if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + s"input to function $prettyName should have ${StringType.catalogString} or " + + s"${BinaryType.catalogString}, but it's " + + inputTypes.map(_.catalogString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } @@ -288,7 +290,7 @@ case class Elt(children: Seq[Expression]) extends Expression { }.mkString) ev.copy( - s""" + code""" |${index.code} |final int $indexVal = ${index.value}; |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; @@ -654,7 +656,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -671,7 +673,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -754,7 +756,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -771,7 +773,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -856,7 +858,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ 
code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -873,7 +875,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -1024,7 +1026,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -1350,7 +1352,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - ev.copy(code = s""" + ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1552,10 +1554,9 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run * A function that returns the char length of the given string expression or * number of bytes of the given binary expression. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of " + - "binary data. The length of string data includes the trailing spaces. The length of binary " + - "data includes binary zeros.", + usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of binary data. The length of string data includes the trailing spaces. The length of binary data includes binary zeros.", examples = """ Examples: > SELECT _FUNC_('Spark SQL '); @@ -1565,6 +1566,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run > SELECT CHARACTER_LENGTH('Spark SQL '); 10 """) +// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1915,12 +1917,15 @@ case class Encode(value: Expression, charset: Expression) usage = """ _FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2` decimal places. If `expr2` is 0, the result has no decimal point or fractional part. + `expr2` also accept a user specified format. This is supposed to function like MySQL's FORMAT. 
""", examples = """ Examples: > SELECT _FUNC_(12332.123456, 4); 12,332.1235 + > SELECT _FUNC_(12332.123456, '##################.###'); + 12332.123 """) case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -1929,14 +1934,20 @@ case class FormatNumber(x: Expression, d: Expression) override def right: Expression = d override def dataType: DataType = StringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(NumericType, TypeCollection(IntegerType, StringType)) + + private val defaultFormat = "#,###,###,###,###,###,##0" // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Option[Int] = None + private var lastDIntValue: Option[Int] = None + + @transient + private var lastDStringValue: Option[String] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. @@ -1949,33 +1960,49 @@ case class FormatNumber(x: Expression, d: Expression) private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { - val dValue = dObject.asInstanceOf[Int] - if (dValue < 0) { - return null - } - - lastDValue match { - case Some(last) if last == dValue => - // use the current pattern - case _ => - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") - } + right.dataType match { + case IntegerType => + val dValue = dObject.asInstanceOf[Int] + if (dValue < 0) { + return null } - lastDValue = Some(dValue) + lastDIntValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append(defaultFormat) + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + + lastDIntValue = Some(dValue) - numberFormat.applyLocalizedPattern(pattern.toString) + numberFormat.applyLocalizedPattern(pattern.toString) + } + case StringType => + val dValue = dObject.asInstanceOf[UTF8String].toString + lastDStringValue match { + case Some(last) if last == dValue => + case _ => + pattern.delete(0, pattern.length) + lastDStringValue = Some(dValue) + if (dValue.isEmpty) { + numberFormat.applyLocalizedPattern(defaultFormat) + } else { + numberFormat.applyLocalizedPattern(dValue) + } + } } x.dataType match { @@ -2007,35 +2034,52 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. 
val usLocale = "US" - val i = ctx.freshName("i") - val dFormat = ctx.freshName("dFormat") - val lastDValue = - ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") - val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") - s""" - if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); - - if ($d > 0) { - $pattern.append("."); - for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); + right.dataType match { + case IntegerType => + val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") + val i = ctx.freshName("i") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") + s""" + if ($d >= 0) { + $pattern.delete(0, $pattern.length()); + if ($d != $lastDValue) { + $pattern.append("$defaultFormat"); + + if ($d > 0) { + $pattern.append("."); + for (int $i = 0; $i < $d; $i++) { + $pattern.append("0"); + } + } + $lastDValue = $d; + $numberFormat.applyLocalizedPattern($pattern.toString()); } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + } else { + ${ev.value} = null; + ${ev.isNull} = true; } - $lastDValue = $d; - $numberFormat.applyLocalizedPattern($pattern.toString()); - } - ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); - } else { - ${ev.value} = null; - ${ev.isNull} = true; - } - """ + """ + case StringType => + val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""") + val dValue = ctx.freshName("dValue") + s""" + String $dValue = $d.toString(); + if (!$dValue.equals($lastDValue)) { + $lastDValue = $dValue; + if ($dValue.isEmpty()) { + $numberFormat.applyLocalizedPattern("$defaultFormat"); + } else { + $numberFormat.applyLocalizedPattern($dValue); + } + } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + """ + } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 6acc87a3e7367..fc1caed84e272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -117,10 +117,10 @@ object SubExprUtils extends PredicateHelper { def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { splitConjunctivePredicates(condition).exists { case _: Exists | Not(_: Exists) => false - case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false + case _: InSubquery | Not(_: InSubquery) => false case e => e.find { x => x.isInstanceOf[Not] && e.find { - case In(_, Seq(_: ListQuery)) => true + case _: InSubquery => true case _ => false }.isDefined }.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 78895f1c2f6f5..707f312499734 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,7 +21,8 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, 
UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ /** @@ -70,9 +71,9 @@ case class WindowSpecDefinition( case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && !isValidFrameType(f.valueBoundary.head.dataType) => TypeCheckFailure( - s"The data type '${orderSpec.head.dataType.simpleString}' used in the order " + + s"The data type '${orderSpec.head.dataType.catalogString}' used in the order " + "specification does not match the data type " + - s"'${f.valueBoundary.head.dataType.simpleString}' which is used in the range frame.") + s"'${f.valueBoundary.head.dataType.catalogString}' which is used in the range frame.") case _ => TypeCheckSuccess } } @@ -251,7 +252,7 @@ case class SpecifiedWindowFrame( TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") case e: Expression if !frameType.inputType.acceptsType(e.dataType) => TypeCheckFailure( - s"The data type of the $location bound '${e.dataType.simpleString}' does not match " + + s"The data type of the $location bound '${e.dataType.catalogString}' does not match " + s"the expected data type '${frameType.inputType.simpleString}'.") case _ => TypeCheckSuccess } @@ -297,6 +298,37 @@ trait WindowFunction extends Expression { def frame: WindowFrame = UnspecifiedFrame } +/** + * Case objects that describe whether a window function is a SQL window function or a Python + * user-defined window function. + */ +sealed trait WindowFunctionType + +object WindowFunctionType { + case object SQL extends WindowFunctionType + case object Python extends WindowFunctionType + + def functionType(windowExpression: NamedExpression): WindowFunctionType = { + val t = windowExpression.collectFirst { + case _: WindowFunction | _: AggregateFunction => SQL + case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python + } + + // Normally a window expression would either have a SQL window function, a SQL + // aggregate function or a python window UDF. However, sometimes the optimizer will replace + // the window function if the value of the window function can be predetermined. + // For example, for query: + // + // select count(NULL) over () from values 1.0, 2.0, 3.0 T(a) + // + // The window function will be replaced by expression literal(0) + // To handle this case, if a window expression doesn't have a regular window function, we + // consider its type to be SQL as literal(0) is also a SQL expression. + t.getOrElse(SQL) + } +} + + /** * An offset window function is a window function that returns the value of the input column offset * by a number of rows within the partition. 
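[Editorial sketch] The `WindowFunctionType.functionType` logic above is a `collectFirst` over the window expression tree, defaulting to SQL when no window function remains (for example after the optimizer has replaced it with a literal). A toy sketch of the same classification over a hypothetical expression ADT, not Catalyst's:

// Pick SQL for the first SQL aggregate/window function found, Python for a window Pandas
// UDF, and default to SQL when neither appears.
sealed trait Expr {
  def children: Seq[Expr]
  def collectFirst[T](pf: PartialFunction[Expr, T]): Option[T] =
    pf.lift(this).orElse(children.view.flatMap(_.collectFirst(pf)).headOption)
}
case class SqlAgg(children: Seq[Expr] = Nil) extends Expr
case class PandasWindowUdf(children: Seq[Expr] = Nil) extends Expr
case class Lit(value: Any) extends Expr { def children: Seq[Expr] = Nil }

sealed trait FuncType
case object SQL extends FuncType
case object Python extends FuncType

def functionType(e: Expr): FuncType = e.collectFirst {
  case _: SqlAgg          => SQL
  case _: PandasWindowUdf => Python
}.getOrElse(SQL)

assert(functionType(Lit(0)) == SQL)   // literal(0) from `count(NULL) over ()` counts as SQL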
For instance: an OffsetWindowfunction for value x with @@ -342,7 +374,10 @@ abstract class OffsetWindowFunction override lazy val frame: WindowFrame = { val boundary = direction match { case Ascending => offset - case Descending => UnaryMinus(offset) + case Descending => UnaryMinus(offset) match { + case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + case o => o + } } SpecifiedWindowFrame(RowFrame, boundary, boundary) } @@ -442,7 +477,7 @@ abstract class RowNumberLike extends AggregateWindowFunction { protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil override val initialValues: Seq[Expression] = zero :: Nil - override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil + override val updateExpressions: Seq[Expression] = rowNumber + one :: Nil } /** @@ -493,7 +528,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must // return the same value for equal values in the partition. override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) + override val evaluateExpression = rowNumber.cast(DoubleType) / n.cast(DoubleType) override def prettyName: String = "cume_dist" } @@ -553,8 +588,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)() private val bucketsWithPadding = AttributeReference("bucketsWithPadding", IntegerType, nullable = false)() - private def bucketOverflow(e: Expression) = - If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero) + private def bucketOverflow(e: Expression) = If(rowNumber >= bucketThreshold, e, zero) override val aggBufferAttributes = Seq( rowNumber, @@ -568,15 +602,14 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow zero, zero, zero, - Cast(Divide(n, buckets), IntegerType), - Cast(Remainder(n, buckets), IntegerType) + (n / buckets).cast(IntegerType), + (n % buckets).cast(IntegerType) ) override val updateExpressions = Seq( - Add(rowNumber, one), - Add(bucket, bucketOverflow(one)), - Add(bucketThreshold, bucketOverflow( - Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))), + rowNumber + one, + bucket + bucketOverflow(one), + bucketThreshold + bucketOverflow(bucketSize + If(bucket < bucketsWithPadding, one, zero)), NoOp, NoOp ) @@ -610,7 +643,7 @@ abstract class RankLike extends AggregateWindowFunction { protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() protected val zero = Literal(0) protected val one = Literal(1) - protected val increaseRowNumber = Add(rowNumber, one) + protected val increaseRowNumber = rowNumber + one /** * Different RankLike implementations use different source expressions to update their rank value. @@ -619,7 +652,7 @@ abstract class RankLike extends AggregateWindowFunction { protected def rankSource: Expression = rowNumber /** Increase the rank when the current rank == 0 or when the one of order attributes changes. 
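[Editorial sketch] The `NTile` buffer arithmetic above reads as plain integer math: each bucket holds `n / buckets` rows, and the first `n % buckets` buckets are padded with one extra row, which is what `bucketsWithPadding` and `bucketOverflow` implement. A standalone sketch (hypothetical function name):

// Compute NTILE(buckets) over n ordered rows without the aggregate-buffer machinery.
def ntileAssignments(n: Int, buckets: Int): Seq[Int] = {
  val bucketSize = n / buckets
  val padded = n % buckets
  (1 to buckets).flatMap { b =>
    val size = bucketSize + (if (b <= padded) 1 else 0)
    Seq.fill(size)(b)
  }
}

assert(ntileAssignments(10, 4) == Seq(1, 1, 1, 2, 2, 2, 3, 3, 4, 4))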
*/ - protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource) + protected val increaseRank = If(orderEquals && rank =!= zero, rank, rankSource) override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs override val initialValues = zero +: one +: orderInit @@ -673,7 +706,7 @@ case class Rank(children: Seq[Expression]) extends RankLike { case class DenseRank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) - override protected def rankSource = Add(rank, one) + override protected def rankSource = rank + one override val updateExpressions = increaseRank +: children override val aggBufferAttributes = rank +: orderAttrs override val initialValues = zero +: orderInit @@ -702,8 +735,7 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase def this() = this(Nil) override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) override def dataType: DataType = DoubleType - override val evaluateExpression = If(GreaterThan(n, one), - Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), - Literal(0.0d)) + override val evaluateExpression = + If(n > one, (rank - one).cast(DoubleType) / (n - one).cast(DoubleType), 0.0d) override def prettyName: String = "percent_rank" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index a3cc4529b5456..deceec73dda30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -47,6 +47,22 @@ sealed trait IdentifierWithDatabase { override def toString: String = quotedString } +/** + * Encapsulates an identifier that is either a alias name or an identifier that has table + * name and optionally a database name. + * The SubqueryAlias node keeps track of the qualifier using the information in this structure + * @param identifier - Is an alias name or a table name + * @param database - Is a database name and is optional + */ +case class AliasIdentifier(identifier: String, database: Option[String]) + extends IdentifierWithDatabase { + + def this(identifier: String) = this(identifier, None) +} + +object AliasIdentifier { + def apply(identifier: String): AliasIdentifier = new AliasIdentifier(identifier) +} /** * Identifies a table in a database. 
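[Editorial sketch] The `PercentRank` evaluate expression above is `(rank - 1) / (n - 1)`, with `0.0` for single-row partitions, and `DenseRank` increments its rank source by one per distinct value. A small sketch computing rank, dense_rank, and percent_rank over an already-sorted partition (hypothetical helper, not the aggregate-buffer implementation):

def rankings(sorted: Seq[Int]): Seq[(Int, Int, Int, Double)] = {
  val n = sorted.length
  sorted.map { v =>
    val rank = sorted.indexOf(v) + 1                 // ties share the first peer's row number
    val denseRank = sorted.distinct.indexOf(v) + 1   // one step per distinct value
    val percentRank = if (n > 1) (rank - 1).toDouble / (n - 1) else 0.0
    (v, rank, denseRank, percentRank)
  }
}

rankings(Seq(10, 20, 20, 30))
// Seq((10,1,1,0.0), (20,2,2,0.333...), (20,2,2,0.333...), (30,4,3,1.0))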
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 025a388aacaa5..3e8e6db1dbd22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -18,10 +18,14 @@ package org.apache.spark.sql.catalyst.json import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} +import java.nio.channels.Channels +import java.nio.charset.Charset import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text +import sun.nio.cs.StreamDecoder +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.unsafe.types.UTF8String private[sql] object CreateJacksonParser extends Serializable { @@ -43,7 +47,48 @@ private[sql] object CreateJacksonParser extends Serializable { jsonFactory.createParser(record.getBytes, 0, record.getLength) } - def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(record) + // Jackson parsers can be ranked according to their performance: + // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser + // but it doesn't allow to set encoding explicitly. Actual encoding is detected automatically + // by checking leading bytes of the array. + // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected + // automatically by analyzing first bytes of the input stream. + // 3. Reader based parser. This is the slowest parser used here but it allows to create + // a reader with specific encoding. + // The method creates a reader for an array with given encoding and sets size of internal + // decoding buffer according to size of input array. 
+ private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = { + val bais = new ByteArrayInputStream(in, 0, length) + val byteChannel = Channels.newChannel(bais) + val decodingBufferSize = Math.min(length, 8192) + val decoder = Charset.forName(enc).newDecoder() + + StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize) + } + + def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = { + val sd = getStreamDecoder(enc, record.getBytes, record.getLength) + jsonFactory.createParser(sd) + } + + def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = { + jsonFactory.createParser(is) + } + + def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { + jsonFactory.createParser(new InputStreamReader(is, enc)) + } + + def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val ba = row.getBinary(0) + + jsonFactory.createParser(ba, 0, ba.length) + } + + def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val binary = row.getBinary(0) + val sd = getStreamDecoder(enc, binary, binary.length) + + jsonFactory.createParser(sd) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5c9adc3332bc0..47eeb70e00427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.nio.charset.StandardCharsets +import java.nio.charset.{Charset, StandardCharsets} import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util._ * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ private[sql] class JSONOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { @@ -73,6 +73,9 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + // Whether to ignore column of all null values or empty array/struct during schema inference + val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) @@ -86,14 +89,28 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + /** + * A string between two consecutive JSON records. + */ val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => require(sep.nonEmpty, "'lineSep' cannot be an empty string.") sep } - // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) - // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + + protected def checkedEncoding(enc: String): String = enc + + /** + * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. 
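[Editorial sketch] The reader-based path above exists so that a non-UTF-8 charset can be forced; `sun.nio.cs.StreamDecoder` is a JDK-internal optimization, but the same idea can be sketched with a plain `InputStreamReader`, since `JsonFactory.createParser(Reader)` is public Jackson API. The helper name below is hypothetical:

import java.io.{ByteArrayInputStream, InputStreamReader}
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}

// Create a Jackson parser over bytes known to be UTF-16LE: wrap them in a reader with the
// explicit charset instead of letting Jackson auto-detect the encoding. This is the
// "slowest but encoding-aware" option in the ranking above.
def parserForEncoding(factory: JsonFactory, bytes: Array[Byte], enc: String): JsonParser = {
  val reader = new InputStreamReader(new ByteArrayInputStream(bytes), enc)
  factory.createParser(reader)
}

val factory = new JsonFactory()
val utf16Bytes = """{"k": 1}""".getBytes("UTF-16LE")
val parser = parserForEncoding(factory, utf16Bytes, "UTF-16LE")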
+ * If the encoding is not specified (None) in read, it will be detected automatically + * when the multiLine option is set to `true`. If encoding is not specified in write, + * UTF-8 is used by default. + */ + val encoding: Option[String] = parameters.get("encoding") + .orElse(parameters.get("charset")).map(checkedEncoding) + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.getOrElse("UTF-8")) + } val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") /** Sets config options on a Jackson [[JsonFactory]]. */ @@ -108,3 +125,46 @@ private[sql] class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars) } } + +private[sql] class JSONOptionsInRead( + @transient override val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) { + + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + protected override def checkedEncoding(enc: String): String = { + val isBlacklisted = JSONOptionsInRead.blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + |Blacklist: ${JSONOptionsInRead.blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } +} + +private[sql] object JSONOptionsInRead { + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq( + Charset.forName("UTF-16"), + Charset.forName("UTF-32") + ) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9c413de752a8c..738947766adda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -45,14 +45,14 @@ private[sql] class JacksonGenerator( // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. 
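[Editorial sketch] At the user level these options surface through the JSON data source. A hedged example assuming an existing SparkSession `spark` and hypothetical paths, using the `encoding`, `lineSep`, and `multiLine` options referenced above:

// Per-line mode with a non-UTF-8 encoding requires an explicit lineSep, while UTF-16/UTF-32
// are only accepted when multiLine is enabled (see the blacklist check above).
val df = spark.read
  .option("encoding", "UTF-16LE")
  .option("lineSep", "\n")
  .json("/tmp/events-utf16le.json")     // hypothetical path

// Multi-line JSON may use an explicit charset such as UTF-32BE, or rely on auto-detection.
val multi = spark.read
  .option("multiLine", "true")
  .option("encoding", "UTF-32BE")
  .json("/tmp/events-utf32be.json")     // hypothetical path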
require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], - "JacksonGenerator only supports to be initialized with a StructType " + - s"or MapType but got ${dataType.simpleString}") + s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString} " + + s"or ${MapType.simpleString} but got ${dataType.catalogString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { case st: StructType => st.map(_.dataType).map(makeWriter).toArray case _ => throw new UnsupportedOperationException( - s"Initial type ${dataType.simpleString} must be a struct") + s"Initial type ${dataType.catalogString} must be a struct") } // `ValueWriter` for array data storing rows of the schema. @@ -70,7 +70,7 @@ private[sql] class JacksonGenerator( private lazy val mapElementWriter: ValueWriter = dataType match { case mt: MapType => makeWriter(mt.valueType) case _ => throw new UnsupportedOperationException( - s"Initial type ${dataType.simpleString} must be a map") + s"Initial type ${dataType.catalogString} must be a map") } private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7f6956994f31f..984979ac5e9b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.json -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, CharConversionException} +import java.nio.charset.MalformedInputException import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -36,7 +37,7 @@ import org.apache.spark.util.Utils * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( - schema: StructType, + schema: DataType, val options: JSONOptions) extends Logging { import JacksonUtils._ @@ -57,7 +58,15 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. 
*/ - private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { + private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = { + dt match { + case st: StructType => makeStructRootConverter(st) + case mt: MapType => makeMapRootConverter(mt) + case at: ArrayType => makeArrayRootConverter(at) + } + } + + private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { val elementConverter = makeConverter(st) val fieldConverters = st.map(_.dataType).map(makeConverter).toArray (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) { @@ -87,6 +96,42 @@ class JacksonParser( } } + private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = { + val fieldConverter = makeConverter(mt.valueType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) { + case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter))) + } + } + + private def makeArrayRootConverter(at: ArrayType): JsonParser => Seq[InternalRow] = { + val elemConverter = makeConverter(at.elementType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, at) { + case START_ARRAY => Seq(InternalRow(convertArray(parser, elemConverter))) + case START_OBJECT if at.elementType.isInstanceOf[StructType] => + // This handles the case when an input JSON object is a structure but + // the specified schema is an array of structures. In that case, the input JSON is + // considered as an array of only one element of struct type. + // This behavior was introduced by changes for SPARK-19595. + // + // For example, if the specified schema is ArrayType(new StructType().add("i", IntegerType)) + // and JSON input as below: + // + // [{"i": 1}, {"i": 2}] + // [{"i": 3}] + // {"i": 4} + // + // The last row is considered as an array with one element, and result of conversion: + // + // Seq(Row(1), Row(2)) + // Seq(Row(3)) + // Seq(Row(4)) + // + val st = at.elementType.asInstanceOf[StructType] + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + Seq(InternalRow(new GenericArrayData(Seq(convertObject(parser, st, fieldConverters))))) + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. @@ -129,7 +174,8 @@ class JacksonParser( case "NaN" => Float.NaN case "Infinity" => Float.PositiveInfinity case "-Infinity" => Float.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") + case other => throw new RuntimeException( + s"Cannot parse $other as ${FloatType.catalogString}.") } } @@ -144,7 +190,8 @@ class JacksonParser( case "NaN" => Double.NaN case "Infinity" => Double.PositiveInfinity case "-Infinity" => Double.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") + case other => + throw new RuntimeException(s"Cannot parse $other as ${DoubleType.catalogString}.") } } @@ -356,11 +403,19 @@ class JacksonParser( } } } catch { - case e @ (_: RuntimeException | _: JsonProcessingException) => + case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) => // JSON parser currently doesn't support partial results for corrupted records. // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. 
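[Editorial sketch] The new map/array root converters are what allow a non-struct type at the top level of the parsed JSON. A hedged example using `from_json` with an `ArrayType` schema, assuming an existing SparkSession `spark` and a build that includes this change; it illustrates the single-object-as-one-element-array behavior described in the comment above:

import org.apache.spark.sql.functions.{col, from_json}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StructType}
import spark.implicits._

// Parse JSON strings whose top-level value is an array of structs. A bare object such as
// {"i": 4} is treated as a one-element array (SPARK-19595 behavior).
val schema = ArrayType(new StructType().add("i", IntegerType))
val df = Seq("""[{"i": 1}, {"i": 2}]""", """{"i": 4}""").toDF("json")
df.select(from_json(col("json"), schema).as("parsed")).show(false)
// approximately: [[1], [2]]
//                [[4]]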
throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. + |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index 134d16e981a15..f26b194e7a7ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -52,7 +52,7 @@ object JacksonUtils { case _ => throw new UnsupportedOperationException( - s"Unable to convert column $name of type ${dataType.simpleString} to JSON.") + s"Unable to convert column $name of type ${dataType.catalogString} to JSON.") } schema.foreach(field => verifyType(field.name, field.dataType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index a270a6451d5dd..9999a005106f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.util.Comparator @@ -25,8 +25,8 @@ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,8 +45,9 @@ private[sql] object JsonInferSchema { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - // perform schema inference on each row and merge afterwards - val rootType = json.mapPartitions { iter => + // In each RDD partition, perform schema inference on each row and merge afterwards. + val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode) + val mergedTypesFromPartitions = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -66,11 +67,22 @@ private[sql] object JsonInferSchema { s"Parse Mode: ${FailFastMode.name}.", e) } } + }.reduceOption(typeMerger).toIterator + } + + // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running + // the fold functions in the scheduler event loop thread. 
+ val existingConf = SQLConf.get + var rootType: DataType = StructType(Nil) + val foldPartition = (iter: Iterator[DataType]) => iter.fold(StructType(Nil))(typeMerger) + val mergeResult = (index: Int, taskResult: DataType) => { + rootType = SQLConf.withExistingConf(existingConf) { + typeMerger(rootType, taskResult) } - }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + } + json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult) - canonicalizeType(rootType) match { + canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -98,7 +110,7 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType @@ -176,33 +188,33 @@ private[sql] object JsonInferSchema { } /** - * Convert NullType to StringType and remove StructTypes with no fields + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. */ - private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => - for { - canonicalType <- canonicalizeType(elementType) - } yield { - at.copy(canonicalType) - } + private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) case StructType(fields) => - val canonicalFields: Array[StructField] = for { - field <- fields - if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType) - } yield { - field.copy(dataType = canonicalType) + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) } - - if (canonicalFields.length > 0) { - Some(StructType(canonicalFields)) + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None } else { - // per SPARK-8093: empty structs should be deleted + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { None + } else { + Some(StringType) } - case NullType => Some(StringType) case other => Some(other) } @@ -290,8 +302,10 @@ private[sql] object JsonInferSchema { // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. 
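[Editorial sketch] The new `dropFieldIfAllNull` handling in `canonicalizeType` is driven by a reader option. A hedged example assuming an existing SparkSession `spark` and a hypothetical input path:

// Given input where field "b" is null in every record, e.g.
//   {"a": 1, "b": null}
//   {"a": 2, "b": null}
// schema inference normally maps the all-null column to StringType; with the option below
// the column is dropped from the inferred schema instead.
val df = spark.read
  .option("dropFieldIfAllNull", "true")
  .json("/tmp/sparse.json")     // hypothetical path

df.printSchema()
// root
//  |-- a: long (nullable = true)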
- assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") - assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") + assert(isSorted(fields1), + s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), + s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}") val newFields = new java.util.ArrayList[StructField]() @@ -329,8 +343,8 @@ private[sql] object JsonInferSchema { ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in - // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when - // the given `DecimalType` is not capable of the given `IntegralType`. + // `findTightestCommonType`. Both cases below will be executed only when the given + // `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 913354e4df0e6..e4b4f1ecbe21f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -46,7 +46,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) - def batches: Seq[Batch] = { + /** + * Defines the default rule batches in the Optimizer. + * + * Implementations of this class should override this method, and [[nonExcludableRules]] if + * necessary, instead of [[batches]]. The rule batches that eventually run in the Optimizer, + * i.e., returned by [[batches]], will be (defaultBatches - (excludedRules - nonExcludableRules)). + */ + def defaultBatches: Seq[Batch] = { val operatorOptimizationRuleSet = Seq( // Operator push down @@ -123,11 +130,21 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + // run this once earlier. this might simplify the plan and reduce cost of optimizer. + // for example, a query such as Filter(LocalRelation) would go through all the heavy + // optimizer rules that are triggered when there is a filter + // (e.g. InferFiltersFromConstraints). if we run this batch earlier, the query becomes just + // LocalRelation and does not trigger many rules + Batch("LocalRelation early", fixedPoint, + ConvertToLocalRelation, + PropagateEmptyRelation) :: Batch("Pullup Correlated Expressions", Once, PullupCorrelatedPredicates) :: Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, + RewriteExceptAll, + RewriteIntersectAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, @@ -160,14 +177,51 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) UpdateNullabilityInAttributeReferences) } + /** + * Defines rules that cannot be excluded from the Optimizer even if they are specified in + * SQL config "excludedRules". + * + * Implementations of this class can override this method if necessary. 
The rule batches + * that eventually run in the Optimizer, i.e., returned by [[batches]], will be + * (defaultBatches - (excludedRules - nonExcludableRules)). + */ + def nonExcludableRules: Seq[String] = + EliminateDistinct.ruleName :: + EliminateSubqueryAliases.ruleName :: + EliminateView.ruleName :: + ReplaceExpressions.ruleName :: + ComputeCurrentTime.ruleName :: + GetCurrentDatabase(sessionCatalog).ruleName :: + RewriteDistinctAggregates.ruleName :: + ReplaceDeduplicateWithAggregate.ruleName :: + ReplaceIntersectWithSemiJoin.ruleName :: + ReplaceExceptWithFilter.ruleName :: + ReplaceExceptWithAntiJoin.ruleName :: + RewriteExceptAll.ruleName :: + RewriteIntersectAll.ruleName :: + ReplaceDistinctWithAggregate.ruleName :: + PullupCorrelatedPredicates.ruleName :: + RewriteCorrelatedScalarSubquery.ruleName :: + RewritePredicateSubquery.ruleName :: Nil + /** * Optimize all the subqueries inside expression. */ object OptimizeSubqueries extends Rule[LogicalPlan] { + private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = { + plan match { + case Sort(_, _, child) => child + case Project(fields, child) => Project(fields, removeTopLevelSort(child)) + case other => other + } + } def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s: SubqueryExpression => val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) - s.withNewPlan(newPlan) + // At this point we have an optimized subquery plan that we are going to attach + // to this subquery expression. Here we can safely remove any top level sort + // in the plan as tuples produced by a subquery are un-ordered. + s.withNewPlan(removeTopLevelSort(newPlan)) } } @@ -175,6 +229,48 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) * Override to provide additional rules for the operator optimization batch. */ def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + /** + * Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that + * eventually run in the Optimizer. + * + * Implementations of this class should override [[defaultBatches]], and [[nonExcludableRules]] + * if necessary, instead of this method. 
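A hedged illustration of how this exclusion hook is meant to be used, assuming the SQLConf entry read by optimizerExcludedRules is exposed as the key spark.sql.optimizer.excludedRules and that `spark` is an active SparkSession:

    // ruleName is the rule object's fully qualified class name, so rules are excluded by full name.
    spark.conf.set("spark.sql.optimizer.excludedRules",
      "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation," +
        "org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation")
    // Listing a rule from nonExcludableRules (e.g. PullupCorrelatedPredicates) has no effect;
    // the optimizer keeps that rule and logs a warning instead.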
+ */ + final override def batches: Seq[Batch] = { + val excludedRulesConf = + SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq) + val excludedRules = excludedRulesConf.filter { ruleName => + val nonExcludable = nonExcludableRules.contains(ruleName) + if (nonExcludable) { + logWarning(s"Optimization rule '${ruleName}' was not excluded from the optimizer " + + s"because this rule is a non-excludable rule.") + } + !nonExcludable + } + if (excludedRules.isEmpty) { + defaultBatches + } else { + defaultBatches.flatMap { batch => + val filteredRules = batch.rules.filter { rule => + val exclude = excludedRules.contains(rule.ruleName) + if (exclude) { + logInfo(s"Optimization rule '${rule.ruleName}' is excluded from the optimizer.") + } + !exclude + } + if (batch.rules == filteredRules) { + Some(batch) + } else if (filteredRules.nonEmpty) { + Some(Batch(batch.name, batch.strategy, filteredRules: _*)) + } else { + logInfo(s"Optimization batch '${batch.name}' is excluded from the optimizer " + + s"as all enclosed rules have been excluded.") + None + } + } + } + } } /** @@ -450,13 +546,16 @@ object ColumnPruning extends Rule[LogicalPlan] { case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty => d.copy(child = prunedChild(child, d.references)) - // Prunes the unused columns from child of Aggregate/Expand/Generate + // Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) + case s @ ScriptTransformation(_, _, _, child, _) + if (child.outputSet -- s.references).nonEmpty => + s.copy(child = prunedChild(child, s.references)) // prune unrequired references case p @ Project(_, g: Generate) if p.references != g.outputSet => @@ -526,9 +625,10 @@ object ColumnPruning extends Rule[LogicalPlan] { /** * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject, - * so remove it. + * so remove it. Since the Projects have been added top-down, we need to remove in bottom-up + * order, otherwise lower Projects can be missed. */ - private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { + private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp { case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) @@ -621,12 +721,15 @@ object CollapseRepartition extends Rule[LogicalPlan] { /** * Collapse Adjacent Window Expression. * - If the partition specs and order specs are the same and the window expression are - * independent, collapse into the parent. + * independent and are of the same window function type, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) - if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty && + // This assumes Window contains the same type of window expressions. 
This is ensured + // by ExtractWindowFunctions. + WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) => w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } @@ -637,13 +740,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * In addition, for left/right outer joins, infer predicate from the preserved side of the Join - * operator and push the inferred filter over to the null-supplying side. For example, if the - * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in - * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will - * be applied to the null-supplying side. + * Note: While this optimization is applicable to a lot of types of join, it primarily benefits + * Inner and LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] + with PredicateHelper with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -664,53 +765,52 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } case join @ Join(left, right, joinType, conditionOpt) => - // Only consider constraints that can be pushed down completely to either the left or the - // right child - val constraints = join.allConstraints.filter { c => - c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } - // Remove those constraints that are already enforced by either the left or the right child - val additionalConstraints = constraints -- (left.constraints ++ right.constraints) - val newConditionOpt = conditionOpt match { - case Some(condition) => - val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt - case None => - additionalConstraints.reduceOption(And) - } - // Infer filter for left/right outer joins - val newLeftOpt = joinType match { - case RightOuter if newConditionOpt.isDefined => - val inferredConstraints = left.getRelevantConstraints( - left.constraints - .union(right.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(left.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, left)) - case _ => None - } - val newRightOpt = joinType match { - case LeftOuter if newConditionOpt.isDefined => - val inferredConstraints = right.getRelevantConstraints( - right.constraints - .union(left.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(right.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, right)) - case _ => None - } + joinType match { + // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an + // inner join, it just drops the right side in the final output. 
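A hedged end-to-end sketch of the inference described above; the tables, columns, and the exact shape of the optimized plan are illustrative only:

    // For an inner join, a predicate on the left join key can be propagated to the right side.
    //   SELECT * FROM t1 JOIN t2 ON t1.a = t2.b WHERE t1.a > 3
    // With a > 3 already pushed onto t1, inferNewFilter can add roughly
    // Filter(isnotnull(b) && b > 3) above t2.
    spark.sql("SELECT * FROM t1 JOIN t2 ON t1.a = t2.b WHERE t1.a > 3").explain(true)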
+ case _: InnerLike | LeftSemi => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + val newRight = inferNewFilter(right, allConstraints) + join.copy(left = newLeft, right = newRight) - if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) - || newLeftOpt.isDefined || newRightOpt.isDefined) { - Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) - } else { - join + // For right outer join, we can only infer additional filters for left side. + case RightOuter => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + join.copy(left = newLeft) + + // For left join, we can only infer additional filters for right side. + case LeftOuter | LeftAnti => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newRight = inferNewFilter(right, allConstraints) + join.copy(right = newRight) + + case _ => join } } + + private def getAllConstraints( + left: LogicalPlan, + right: LogicalPlan, + conditionOpt: Option[Expression]): Set[Expression] = { + val baseConstraints = left.constraints.union(right.constraints) + .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + baseConstraints.union(inferAdditionalConstraints(baseConstraints)) + } + + private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { + val newPredicates = constraints + .union(constructIsNotNullConstraints(constraints, plan.output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic + } -- plan.constraints + if (newPredicates.isEmpty) { + plan + } else { + Filter(newPredicates.reduce(And), plan) + } + } } /** @@ -770,12 +870,29 @@ object EliminateSorts extends Rule[LogicalPlan] { } /** - * Removes Sort operation if the child is already sorted + * Removes redundant Sort operation. This can happen: + * 1) if the child is already sorted + * 2) if there is another Sort operator separated by 0...n Project/Filter operators */ object RemoveRedundantSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => child + case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + } + + def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match { + case Sort(_, _, child) => recursiveRemoveSort(child) + case other if canEliminateSort(other) => + other.withNewChildren(other.children.map(recursiveRemoveSort)) + case _ => plan + } + + def canEliminateSort(plan: LogicalPlan): Boolean = plan match { + case p: Project => p.projectList.forall(_.deterministic) + case f: Filter => f.condition.deterministic + case _: ResolvedHint => true + case _ => false } } @@ -1168,12 +1285,14 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( - s"""Detected cartesian product for ${j.joinType.sql} join between logical plans + s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans |${left.treeString(false).trim} |and |${right.treeString(false).trim} |Join condition is missing or trivial. 
- |Use the CROSS JOIN syntax to allow cartesian products between these relations.""" + |Either: use the CROSS JOIN syntax to allow cartesian products between these + |relations, or: enable implicit cartesian products by setting the configuration + |variable spark.sql.crossJoin.enabled=true""" .stripMargin) } } @@ -1238,6 +1357,12 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) => LocalRelation(output, data.take(limit), isStreaming) + + case Filter(condition, LocalRelation(output, data, isStreaming)) + if !hasUnevaluableExpr(condition) => + val predicate = InterpretedPredicate.create(condition, output) + predicate.initialize(0) + LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming) } private def hasUnevaluableExpr(expr: Expression): Boolean = { @@ -1295,7 +1420,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { */ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Intersect(left, right) => + case Intersect(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) @@ -1316,13 +1441,149 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { */ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Except(left, right) => + case Except(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And))) } } +/** + * Replaces logical [[Except]] operator using a combination of Union, Aggregate + * and Generate operator. 
+ *
+ * Input Query :
+ * {{{
+ *   SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2
+ * }}}
+ *
+ * Rewritten Query:
+ * {{{
+ *   SELECT c1
+ *   FROM (
+ *     SELECT replicate_rows(sum_val, c1)
+ *       FROM (
+ *         SELECT c1, sum_val
+ *           FROM (
+ *             SELECT c1, sum(vcol) AS sum_val
+ *               FROM (
+ *                 SELECT 1L as vcol, c1 FROM ut1
+ *                 UNION ALL
+ *                 SELECT -1L as vcol, c1 FROM ut2
+ *               ) AS union_all
+ *             GROUP BY union_all.c1
+ *           )
+ *         WHERE sum_val > 0
+ *       )
+ *   )
+ * }}}
+ */
+
+object RewriteExceptAll extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Except(left, right, true) =>
+      assert(left.output.size == right.output.size)
+
+      val newColumnLeft = Alias(Literal(1L), "vcol")()
+      val newColumnRight = Alias(Literal(-1L), "vcol")()
+      val modifiedLeftPlan = Project(Seq(newColumnLeft) ++ left.output, left)
+      val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right)
+      val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan)
+      val aggSumCol =
+        Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")()
+      val aggOutputColumns = left.output ++ Seq(aggSumCol)
+      val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan)
+      val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan)
+      val genRowPlan = Generate(
+        ReplicateRows(Seq(aggSumCol.toAttribute) ++ left.output),
+        unrequiredChildIndex = Nil,
+        outer = false,
+        qualifier = None,
+        left.output,
+        filteredAggPlan
+      )
+      Project(left.output, genRowPlan)
+  }
+}
+
+/**
+ * Replaces logical [[Intersect]] operator using a combination of Union, Aggregate
+ * and Generate operator.
+ *
+ * Input Query :
+ * {{{
+ *   SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2
+ * }}}
+ *
+ * Rewritten Query:
+ * {{{
+ *   SELECT c1
+ *   FROM (
+ *     SELECT replicate_rows(min_count, c1)
+ *       FROM (
+ *         SELECT c1, If (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count
+ *           FROM (
+ *             SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt
+ *               FROM (
+ *                 SELECT true as vcol1, null as vcol2, c1 FROM ut1
+ *                 UNION ALL
+ *                 SELECT null as vcol1, true as vcol2, c1 FROM ut2
+ *               ) AS union_all
+ *             GROUP BY c1
+ *             HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1
+ *           )
+ *       )
+ *   )
+ * }}}
+ */
+object RewriteIntersectAll extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Intersect(left, right, true) =>
+      assert(left.output.size == right.output.size)
+
+      val trueVcol1 = Alias(Literal(true), "vcol1")()
+      val nullVcol1 = Alias(Literal(null, BooleanType), "vcol1")()
+
+      val trueVcol2 = Alias(Literal(true), "vcol2")()
+      val nullVcol2 = Alias(Literal(null, BooleanType), "vcol2")()
+
+      // Add a projection on the top of left and right plans to project out
+      // the additional virtual columns.
+      val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left)
+      val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right)
+
+      val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols)
+
+      // Expressions to compute the count of each virtual column and the minimum of the two counts.
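A hedged sketch of how these two rewrites get exercised, reusing the ut1/ut2 names from the doc comments above (the tables themselves are illustrative):

    // Both queries parse to Except(..., isAll = true) / Intersect(..., isAll = true), which
    // RewriteExceptAll and RewriteIntersectAll then expand into the Union + Aggregate +
    // Generate(ReplicateRows) plans shown in the doc comments above.
    spark.sql("SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2").explain(true)
    spark.sql("SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2").explain(true)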
+ val vCol1AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")() + val vCol2AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")() + val ifExpression = Alias(If( + GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute), + vCol2AggrExpr.toAttribute, + vCol1AggrExpr.toAttribute + ), "min_count")() + + val aggregatePlan = Aggregate(left.output, + Seq(vCol1AggrExpr, vCol2AggrExpr) ++ left.output, unionPlan) + val filterPlan = Filter(And(GreaterThanOrEqual(vCol1AggrExpr.toAttribute, Literal(1L)), + GreaterThanOrEqual(vCol2AggrExpr.toAttribute, Literal(1L))), aggregatePlan) + val projectMinPlan = Project(left.output ++ Seq(ifExpression), filterPlan) + + // Apply the replicator to replicate rows based on min_count + val genRowPlan = Generate( + ReplicateRows(Seq(ifExpression.toAttribute) ++ left.output), + unrequiredChildIndex = Nil, + outer = false, + qualifier = None, + left.output, + projectMinPlan + ) + Project(left.output, genRowPlan) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 45edf266bbce4..efd3944eba7f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,7 +46,7 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case e @ Except(left, right) if isEligible(left, right) => + case e @ Except(left, right, false) if isEligible(left, right) => val newCondition = transformCondition(left, skipProject(right)) newCondition.map { c => Distinct(Filter(Not(c), left)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1c0b7bd806801..5629b72894225 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -219,15 +218,24 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case In(v, list) if list.isEmpty => + // When v is not nullable, the following expression will be optimized + // to FalseLiteral which is tested in OptimizeInSuite.scala + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > 
SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.length == 1 + // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, + // TODO: we exclude them in this rule. + && !v.isInstanceOf[CreateNamedStructLike] + && !newList.head.isInstanceOf[CreateNamedStructLike]) { + EqualTo(v, newList.head) + } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.size < list.size) { + } else if (newList.length < list.length) { expr.copy(list = newList) - } else { // newList.length == list.length + } else { // newList.length == list.length && newList.length > 1 expr } } @@ -382,6 +390,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue + case If(cond, trueValue, falseValue) + if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. @@ -395,17 +405,35 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = newBranches) } - case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) => + case CaseWhen(branches, _) if branches.headOption.map(_._1).contains(TrueLiteral) => // If the first branch is a true literal, remove the entire CaseWhen and use the value // from that. Note that CaseWhen.branches should never be empty, and as a result the // headOption (rather than head) added above is just an extra (and unnecessary) safeguard. branches.head._2 case CaseWhen(branches, _) if branches.exists(_._1 == TrueLiteral) => - // a branc with a TRue condition eliminates all following branches, + // a branch with a true condition eliminates all following branches, // these branches can be pruned away val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) + + case e @ CaseWhen(branches, Some(elseValue)) + if branches.forall(_._2.semanticEquals(elseValue)) => + // For non-deterministic conditions with side effect, we can not remove it, or change + // the ordering. As a result, we try to remove the deterministic conditions from the tail. + var hitNonDeterministicCond = false + var i = branches.length + while (i > 0 && !hitNonDeterministicCond) { + hitNonDeterministicCond = !branches(i - 1)._1.deterministic + if (!hitNonDeterministicCond) { + i -= 1 + } + } + if (i == 0) { + elseValue + } else { + e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) + } } } } @@ -495,6 +523,7 @@ object NullPropagation extends Rule[LogicalPlan] { // If the value expression is NULL then transform the In expression to null literal. case In(Literal(null, _), _) => Literal.create(null, BooleanType) + case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. @@ -643,6 +672,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } + /** * Combine nested [[Concat]] expressions. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 709db6d8bec7d..e9b7a8b76e683 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -42,13 +43,6 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def getValueExpression(e: Expression): Seq[Expression] = { - e match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match { // SPARK-21835: It is possibly that the two sides of the join have conflicting attributes, // the produced join then becomes unresolved and break structural integrity. We should @@ -97,34 +91,35 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) - case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) => - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) => + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) - case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) => + case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: // (a1,a2,...) = (b1,b2,...) // to // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ... - val joinConds = splitConjunctivePredicates(joinCond.get) + val baseJoinConds = splitConjunctivePredicates(joinCond.get) + val nullAwareJoinConds = baseJoinConds.map(c => Or(c, IsNull(c))) // After that, add back the correlated join predicate(s) in the subquery // Example: // SELECT ... 
FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) // will have the final conditions in the LEFT ANTI as - // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) - val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) + // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 + val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs))) + dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond))) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) @@ -149,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case In(value, Seq(ListQuery(sub, conditions, _, _))) => + case InSubquery(values, ListQuery(sub, conditions, _, _)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) // Deduplicate conflicting attributes if any. newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdc357d54a878..7bc1f63e30540 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -503,18 +503,24 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val join = right.optionalMap(left)(Join(_, _, Inner, None)) withJoinRelations(join, relation) } - ctx.lateralView.asScala.foldLeft(from)(withGenerate) + if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } } /** * Connect two queries by a Set operator. 
* * Supported Set operators are: - * - UNION [DISTINCT] - * - UNION ALL - * - EXCEPT [DISTINCT] - * - MINUS [DISTINCT] - * - INTERSECT [DISTINCT] + * - UNION [ DISTINCT | ALL ] + * - EXCEPT [ DISTINCT | ALL ] + * - MINUS [ DISTINCT | ALL ] + * - INTERSECT [DISTINCT | ALL] */ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { val left = plan(ctx.left) @@ -526,17 +532,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.UNION => Distinct(Union(left, right)) case SqlBaseParser.INTERSECT if all => - throw new ParseException("INTERSECT ALL is not supported.", ctx) + Intersect(left, right, isAll = true) case SqlBaseParser.INTERSECT => - Intersect(left, right) + Intersect(left, right, isAll = false) case SqlBaseParser.EXCEPT if all => - throw new ParseException("EXCEPT ALL is not supported.", ctx) + Except(left, right, isAll = true) case SqlBaseParser.EXCEPT => - Except(left, right) + Except(left, right, isAll = false) case SqlBaseParser.SETMINUS if all => - throw new ParseException("MINUS ALL is not supported.", ctx) + Except(left, right, isAll = true) case SqlBaseParser.SETMINUS => - Except(left, right) + Except(left, right, isAll = false) } } @@ -614,6 +620,38 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging plan } + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText))) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) + Pivot(None, pivotColumn, pivotValues, aggregates, query) + } + + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ @@ -1065,6 +1103,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case not => Not(e) } + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + // Create the predicate. ctx.kind.getType match { case SqlBaseParser.BETWEEN => @@ -1073,7 +1116,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query))))) + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => @@ -1185,6 +1228,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging new StringLocate(expression(ctx.substr), expression(ctx.str)) } + /** + * Create a Extract expression. 
+ */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + ctx.field.getText.toUpperCase(Locale.ROOT) match { + case "YEAR" => + Year(expression(ctx.source)) + case "QUARTER" => + Quarter(expression(ctx.source)) + case "MONTH" => + Month(expression(ctx.source)) + case "WEEK" => + WeekOfYear(expression(ctx.source)) + case "DAY" => + DayOfMonth(expression(ctx.source)) + case "DAYOFWEEK" => + DayOfWeek(expression(ctx.source)) + case "HOUR" => + Hour(expression(ctx.source)) + case "MINUTE" => + Minute(expression(ctx.source)) + case "SECOND" => + Second(expression(ctx.source)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + /** * Create a (windowed) Function expression. */ @@ -1245,6 +1316,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + /** + * Create an [[LambdaFunction]]. + */ + override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { + val arguments = ctx.IDENTIFIER().asScala.map { name => + UnresolvedAttribute.quoted(name.getText) + } + LambdaFunction(expression(ctx.expression), arguments) + } + /** * Create a reference to a window frame, i.e. [[WindowSpecReference]]. */ @@ -1458,7 +1539,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case "TIMESTAMP" => Literal(Timestamp.valueOf(value)) case "X" => - val padding = if (value.length % 2 == 1) "0" else "" + val padding = if (value.length % 2 != 0) "0" else "" Literal(DatatypeConverter.parseHexBinary(padding + value)) case other => throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 4c20f2368bded..7d8cb1f18b4b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -84,12 +84,14 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) + lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) parser.addParseListener(PostProcessor) parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) + parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced try { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 626f905707191..84be677e438a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.planning import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -215,7 +216,7 @@ object PhysicalAggregation { case agg: AggregateExpression if 
!equivalentAggregateExpressions.addExpr(agg) => agg case udf: PythonUDF - if PythonUDF.isGroupAggPandasUDF(udf) && + if PythonUDF.isGroupedAggPandasUDF(udf) && !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -245,7 +246,7 @@ object PhysicalAggregation { equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute // Similar to AggregateExpression - case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) => equivalentAggregateExpressions.getEquivalentExprs(ue).headOption .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => @@ -268,3 +269,40 @@ object PhysicalAggregation { case _ => None } } + +/** + * An extractor used when planning physical execution of a window. This extractor outputs + * the window function type of the logical window. + * + * The input logical window must contain same type of window functions, which is ensured by + * the rule ExtractWindowExpressions in the analyzer. + */ +object PhysicalWindow { + // windowFunctionType, windowExpression, partitionSpec, orderSpec, child + private type ReturnType = + (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) => + + // The window expression should not be empty here, otherwise it's a bug. + if (windowExpressions.isEmpty) { + throw new AnalysisException(s"Window expression is empty in $expr") + } + + val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType) + .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) => + if (t1 != t2) { + // We shouldn't have different window function type here, otherwise it's a bug. + throw new AnalysisException( + s"Found different window function type in $windowExpressions") + } else { + t1 + } + } + + Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child)) + + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 64cb8c726772f..b1ffdca091461 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -27,8 +27,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** * The active config object within the current scope. - * Note that if you want to refer config values during execution, you have to capture them - * in Driver and use the captured values in Executors. * See [[SQLConf.get]] for more information. 
*/ def conf: SQLConf = SQLConf.get @@ -119,6 +117,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other case null => null @@ -285,7 +284,7 @@ object QueryPlan extends PredicateHelper { if (ordinal == -1) { ar } else { - ar.withExprId(ExprId(ordinal)) + ar.withExprId(ExprId(ordinal)).canonicalized } }.canonicalized.asInstanceOf[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala new file mode 100644 index 0000000000000..9404a809b453c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.analysis.CheckAnalysis +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} +import org.apache.spark.util.Utils + + +/** + * [[AnalysisHelper]] defines some infrastructure for the query analyzer. In particular, in query + * analysis we don't want to repeatedly re-analyze sub-plans that have previously been analyzed. + * + * This trait defines a flag `analyzed` that can be set to true once analysis is done on the tree. + * This also provides a set of resolve methods that do not recurse down to sub-plans that have the + * analyzed flag set to true. + * + * The analyzer rules should use the various resolve methods, in lieu of the various transform + * methods defined in [[TreeNode]] and [[QueryPlan]]. + * + * To prevent accidental use of the transform methods, this trait also overrides the transform + * methods to throw exceptions in test mode, if they are used in the analyzer. + */ +trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan => + + private var _analyzed: Boolean = false + + /** + * Recursively marks all nodes in this plan tree as analyzed. + * This should only be called by [[CheckAnalysis]]. + */ + private[catalyst] def setAnalyzed(): Unit = { + if (!_analyzed) { + _analyzed = true + children.foreach(_.setAnalyzed()) + } + } + + /** + * Returns true if this node and its children have already been gone through analysis and + * verification. 
Note that this is only an optimization used to avoid analyzing trees that + * have already been analyzed, and can be reset by transformations. + */ + def analyzed: Boolean = _analyzed + + /** + * Returns a copy of this node where `rule` has been recursively applied to the tree. When + * `rule` does not apply to a given node, it is left unchanged. This function is similar to + * `transform`, but skips sub-trees that have already been marked as analyzed. + * Users should not expect a specific directionality. If a specific directionality is needed, + * [[resolveOperatorsUp]] or [[resolveOperatorsDown]] should be used. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + resolveOperatorsDown(rule) + } + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node, + * it is left unchanged. This function is similar to `transformUp`, but skips sub-trees that + * have already been marked as analyzed. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperatorsUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + val afterRuleOnChildren = mapChildren(_.resolveOperatorsUp(rule)) + if (self fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(self, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } + } else { + self + } + } + + /** Similar to [[resolveOperatorsUp]], but does it top-down. */ + def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(self, identity[LogicalPlan]) + } + + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (self fastEquals afterRule) { + mapChildren(_.resolveOperatorsDown(rule)) + } else { + afterRule.mapChildren(_.resolveOperatorsDown(rule)) + } + } + } else { + self + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed. + */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + resolveOperators { + case p => p.transformExpressions(r) + } + } + + protected def assertNotAnalysisRule(): Unit = { + if (Utils.isTesting && + AnalysisHelper.inAnalyzer.get > 0 && + AnalysisHelper.resolveOperatorDepth.get == 0) { + throw new RuntimeException("This method should not be called in the analyzer") + } + } + + /** + * In analyzer, use [[resolveOperatorsDown()]] instead. If this is used in the analyzer, + * an exception will be thrown in test mode. It is however OK to call this function within + * the scope of a [[resolveOperatorsDown()]] call. + * @see [[TreeNode.transformDown()]]. + */ + override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformDown(rule) + } + + /** + * Use [[resolveOperators()]] in the analyzer. 
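A hedged sketch of what this contract looks like from a rule author's point of view; the rule below is invented purely for illustration:

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
    import org.apache.spark.sql.catalyst.rules.Rule

    // An analyzer-style rule uses resolveOperatorsUp rather than transformUp, so sub-plans
    // whose `analyzed` flag is already set are skipped instead of being re-traversed.
    object DropEmptyProjects extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
        case Project(projectList, child) if projectList.isEmpty => child
      }
    }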
+ * @see [[TreeNode.transformUp()]] + */ + override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformUp(rule) + } + + /** + * Use [[resolveExpressions()]] in the analyzer. + * @see [[QueryPlan.transformAllExpressions()]] + */ + override def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + assertNotAnalysisRule() + super.transformAllExpressions(rule) + } + +} + + +object AnalysisHelper { + + /** + * A thread local to track whether we are in a resolveOperator call (for the purpose of analysis). + * This is an int because resolve* calls might be be nested (e.g. a rule might trigger another + * query compilation within the rule itself), so we are tracking the depth here. + */ + private val resolveOperatorDepth: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + /** + * A thread local to track whether we are in the analysis phase of query compilation. This is an + * int rather than a boolean in case our analyzer recursively calls itself. + */ + private val inAnalyzer: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + def allowInvokingTransformsInAnalyzer[T](f: => T): T = { + resolveOperatorDepth.set(resolveOperatorDepth.get + 1) + try f finally { + resolveOperatorDepth.set(resolveOperatorDepth.get - 1) + } + } + + def markInAnalyzer[T](f: => T): T = { + inAnalyzer.set(inAnalyzer.get + 1) + try f finally { + inAnalyzer.set(inAnalyzer.get - 1) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b05508db786ad..8c4828a4cef23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -43,6 +44,12 @@ object LocalRelation { } } +/** + * Logical plan node for scanning data from a local collection. + * + * @param data The local collection holding the data. It doesn't need to be sent to executors + * and then doesn't need to be serializable. 
+ */ case class LocalRelation( output: Seq[Attribute], data: Seq[InternalRow] = Nil, @@ -71,7 +78,7 @@ case class LocalRelation( } override def computeStats(): Statistics = - Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 42034403d6d03..0e4456ac0e6a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,12 +23,12 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats -import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StructType abstract class LogicalPlan extends QueryPlan[LogicalPlan] + with AnalysisHelper with LogicalPlanStats with QueryPlanConstraints with Logging { @@ -78,7 +78,7 @@ abstract class LogicalPlan schema.map { field => resolve(field.name :: Nil, resolver).map { case a: AttributeReference => a - case other => sys.error(s"can not handle nested schema yet... plan $this") + case _ => sys.error(s"can not handle nested schema yet... plan $this") }.getOrElse { throw new AnalysisException( s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") @@ -86,6 +86,10 @@ abstract class LogicalPlan } } + private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output)) + + private[this] lazy val outputAttributes = AttributeSeq(output) + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as @@ -94,7 +98,7 @@ abstract class LogicalPlan def resolveChildren( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver) + childAttributes.resolve(nameParts, resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -104,7 +108,7 @@ abstract class LogicalPlan def resolve( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, output, resolver) + outputAttributes.resolve(nameParts, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -114,105 +118,7 @@ abstract class LogicalPlan def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * This assumes `name` has multiple parts, where the 1st part is a qualifier - * (i.e. table name, alias, or subquery alias). - * See the comment above `candidates` variable in resolve() for semantics the returned data. 
- */ - private def resolveAsTableColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - assert(nameParts.length > 1) - if (attribute.qualifier.exists(resolver(_, nameParts.head))) { - // At least one qualifier matches. See if remaining parts match. - val remainingParts = nameParts.tail - resolveAsColumn(remainingParts, resolver, attribute) - } else { - None - } - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) - } else { - None - } - } - - /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve( - nameParts: Seq[String], - input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - // A sequence of possible candidate matches. - // Each candidate is a tuple. The first element is a resolved attribute, followed by a list - // of parts that are to be resolved. - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - var candidates: Seq[(Attribute, List[String])] = { - // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (nameParts.length > 1) { - input.flatMap { option => - resolveAsTableColumn(nameParts, resolver, option) - } - } else { - Seq.empty - } - } - - // If none of attributes match `table.column` pattern, we try to resolve it as a column. - if (candidates.isEmpty) { - candidates = input.flatMap { candidate => - resolveAsColumn(nameParts, resolver, candidate) - } - } - - def name = UnresolvedAttribute(nameParts).name - - candidates.distinct match { - // One match, no nested fields, use it. - case Seq((a, Nil)) => Some(a) - - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliased it with the last part of the name. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final - // expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - Some(Alias(fieldExprs, nestedFields.last)()) - - // No matches. - case Seq() => - logTrace(s"Could not find $name in ${input.mkString(", ")}") - None - - // More than one match. 
- case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") - throw new AnalysisException( - s"Reference '$name' is ambiguous, could be: $referenceNames.") - } + outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index a29f3d29236c7..cc352c59dff80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,29 +20,28 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints { self: LogicalPlan => +trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains an additional set of constraints, such as equality - * constraints and `isNotNull` constraints, etc. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. */ - lazy val allConstraints: ExpressionSet = { + lazy val constraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet(validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints))) + ExpressionSet( + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints, output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + } + ) } else { ExpressionSet(Set.empty) } } - /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. - */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) - /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then @@ -52,30 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan => * See [[Canonicalize]] for more details. */ protected def validConstraints: Set[Expression] = Set.empty +} + +trait ConstraintHelper { /** - * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as - * equality constraints and `isNotNull` constraints, etc., and that only contains references - * to this [[LogicalPlan]] node. + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. 
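The inference rule documented just above ("a = 5" together with "a = b" yields "b = 5") can be sketched outside of Catalyst with plain case classes. A minimal, illustrative toy model follows; every name in it is invented, and it is not Spark's Expression machinery:

    // Toy version of the equality-constraint inference moved into ConstraintHelper.
    sealed trait Constraint
    case class AttrEqAttr(l: String, r: String) extends Constraint   // a = b
    case class AttrEqLit(attr: String, v: Int) extends Constraint    // a = 5

    def inferAdditional(cs: Set[Constraint]): Set[Constraint] = {
      val inferred: Set[Constraint] = cs.flatMap {
        case eq @ AttrEqAttr(l, r) =>
          (cs - eq).collect {
            case AttrEqLit(`l`, v) => AttrEqLit(r, v): Constraint   // a = 5, a = b  =>  b = 5
            case AttrEqLit(`r`, v) => AttrEqLit(l, v): Constraint
          }
        case _ => Set.empty[Constraint]
      }
      inferred -- cs   // only report constraints that are actually new
    }

    // inferAdditional(Set(AttrEqLit("a", 5), AttrEqAttr("a", "b"))) == Set(AttrEqLit("b", 5))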
*/ - def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { - val allRelevantConstraints = - if (conf.constraintPropagationEnabled) { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - } else { - constraints - } - ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case _ => // No inference + } + inferredConstraints -- constraints } + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { + case e: Expression if e.semanticEquals(source) => destination + }) + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this * returns a constraint of the form `isNotNull(a)` */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + def constructIsNotNullConstraints( + constraints: Set[Expression], + output: Seq[Attribute]): Set[Expression] = { // First, we propagate constraints from the null intolerant expressions. var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) @@ -111,32 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan => case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. 
- */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - var inferredConstraints = Set.empty[Expression] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case _ => // No inference - } - inferredConstraints -- constraints - } - - private def replaceConstraints( - constraints: Set[Expression], - source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { - case e: Expression if e.semanticEquals(source) => destination - }) - - private def selfReferenceOnly(e: Expression): Boolean = { - e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 10df504795430..7ff83a9be3622 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.{AliasIdentifier} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, - RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -74,7 +74,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) * their output. * * @param generator the generator expression - * @param unrequiredChildIndex this paramter starts as Nil and gets filled by the Optimizer. + * @param unrequiredChildIndex this parameter starts as Nil and gets filled by the Optimizer. * It's used as an optimization for omitting data generation that will * be discarded next by a projection. 
* A common use case is when we explode(array(..)) and are interested @@ -113,7 +113,7 @@ case class Generate( def qualifiedGeneratorOutput: Seq[Attribute] = { val qualifiedOutput = qualifier.map { q => // prepend the new qualifier to the existed one - generatorOutput.map(a => a.withQualifier(Some(q))) + generatorOutput.map(a => a.withQualifier(Seq(q))) }.getOrElse(generatorOutput) val nullableOutput = qualifiedOutput.map { // if outer, make all attributes nullable, otherwise keep existing nullability @@ -164,7 +164,12 @@ object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { +case class Intersect( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean) extends SetOperation(left, right) { + + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => @@ -183,8 +188,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - +case class Except( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean) extends SetOperation(left, right) { + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output @@ -344,6 +352,38 @@ case class Join( } } +/** + * Append data to an existing table. + */ +case class AppendData( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Seq.empty + + override lazy val resolved: Boolean = { + table.resolved && query.resolved && query.output.size == table.output.size && + query.output.zip(table.output).forall { + case (inAttr, outAttr) => + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && + (outAttr.nullable || !inAttr.nullable) + } + } +} + +object AppendData { + def byName(table: NamedRelation, df: LogicalPlan): AppendData = { + new AppendData(table, df, true) + } + + def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { + new AppendData(table, query, false) + } +} + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. @@ -686,17 +726,34 @@ case class GroupingSets( override lazy val resolved: Boolean = false } +/** + * A constructor for creating a pivot, which will later be converted to a [[Project]] + * or an [[Aggregate]] during the query analysis. + * + * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming + * from SQL, in which group by expressions are not explicitly specified. + * @param pivotColumn The pivot column. + * @param pivotValues A sequence of values for the pivot column. + * @param aggregates The aggregation expressions, each with or without an alias. 
+ * @param child Child operator + */ case class Pivot( - groupByExprs: Seq[NamedExpression], + groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, - pivotValues: Seq[Literal], + pivotValues: Seq[Expression], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) - case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + override lazy val resolved = false // Pivot will be replaced after being resolved. + override def output: Seq[Attribute] = { + val pivotAgg = aggregates match { + case agg :: Nil => + pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => + pivotValues.flatMap { value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } } + groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg } } @@ -769,19 +826,37 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr /** * Aliased subquery. * - * @param alias the alias name for this subquery. + * @param name the alias identifier for this subquery. * @param child the logical plan of this subquery. */ case class SubqueryAlias( - alias: String, + name: AliasIdentifier, child: LogicalPlan) extends OrderPreservingUnaryNode { - override def doCanonicalize(): LogicalPlan = child.canonicalized + def alias: String = name.identifier - override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) + override def output: Seq[Attribute] = { + val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias)) + child.output.map(_.withQualifier(qualifierList)) + } + override def doCanonicalize(): LogicalPlan = child.canonicalized } +object SubqueryAlias { + def apply( + identifier: String, + child: LogicalPlan): SubqueryAlias = { + SubqueryAlias(AliasIdentifier(identifier), child) + } + + def apply( + identifier: String, + database: String, + child: LogicalPlan): SubqueryAlias = { + SubqueryAlias(AliasIdentifier(identifier, Some(database)), child) + } +} /** * Sample the dataset. * @@ -899,23 +974,3 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } - -/** - * A logical plan for setting a barrier of analysis. - * - * The SQL Analyzer goes through a whole query plan even most part of it is analyzed. This - * increases the time spent on query analysis for long pipelines in ML, especially. - * - * This logical plan wraps an analyzed logical plan to prevent it from analysis again. The barrier - * is applied to the analyzed logical plan in Dataset. It won't change the output of wrapped - * logical plan and just acts as a wrapper to hide it from analyzer. New operations on the dataset - * will be put on the barrier, so only the new nodes created will be analyzed. - * - * This analysis barrier will be removed at the end of analysis stage. 
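The Pivot.output logic above names the result columns after the pivot values when there is a single aggregate, and as value + "_" + agg.sql when there are several. A hedged DataFrame-level sketch; the table and column names are assumptions, and a SparkSession named spark is taken as given:

    import org.apache.spark.sql.functions.{avg, sum}

    // assumes a table with year, course and earnings columns
    val pivoted = spark.table("courseSales")
      .groupBy("year")
      .pivot("course", Seq("dotNET", "Java"))
      .agg(sum("earnings"), avg("earnings"))

    // With two aggregates, the pivoted columns follow value + "_" + agg.sql, e.g.
    //   year, dotNET_sum(earnings), dotNET_avg(earnings), Java_sum(earnings), Java_avg(earnings)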
- */ -case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { - override protected def innerChildren: Seq[LogicalPlan] = Seq(child) - override def output: Seq[Attribute] = child.output - override def isStreaming: Boolean = child.isStreaming - override def doCanonicalize(): LogicalPlan = child.canonicalized -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 0f147f0ffb135..211a2a0717371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} - object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ @@ -73,13 +71,12 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } - def getOutputSize( + def getSizePerRow( attributes: Seq[Attribute], - outputRowCount: BigInt, attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. - val sizePerRow = 8 + attributes.map { attr => + 8 + attributes.map { attr => if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => @@ -92,10 +89,15 @@ object EstimationUtils { attr.dataType.defaultSize } }.sum + } + def getOutputSize( + attributes: Seq[Attribute], + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero // (simple computation of statistics returns product of children). 
- if (outputRowCount > 0) outputRowCount * sizePerRow else 1 + if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1 } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0538c9d88584b..5a3eeefaedb18 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -395,6 +395,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + val statsInterval = ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => @@ -418,6 +422,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => + if (ndv.toDouble == 0) { + return Some(0.0) + } + newNdv = ndv.min(BigInt(hSet.size)) if (update) { val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 85f67c7d66075..ee43f9126386b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { private def visitUnaryNode(p: UnaryNode): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. - val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + val childRowSize = EstimationUtils.getSizePerRow(p.child.output) + val outputRowSize = EstimationUtils.getSizePerRow(p.output) // Assume there will be the same number of rows as child has. 
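A small worked example of the refactored helpers above, using only default sizes (no column statistics); the attribute names are invented:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
    import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType}

    val output = Seq(
      AttributeReference("i", IntegerType)(),   // defaultSize = 4
      AttributeReference("l", LongType)(),      // defaultSize = 8
      AttributeReference("d", DoubleType)())    // defaultSize = 8

    EstimationUtils.getSizePerRow(output)        // 8 (row overhead) + 4 + 8 + 8 = 28
    EstimationUtils.getOutputSize(output, 1000)  // 28 * 1000 = 28000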
var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala index f46b4ed764e27..693d2a7210ab8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala @@ -69,6 +69,8 @@ object ValueInterval { false case (n1: NumericValueInterval, n2: NumericValueInterval) => n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 + case _ => + throw new UnsupportedOperationException(s"Not supported pair: $r1, $r2 at isIntersected()") } /** @@ -86,6 +88,8 @@ object ValueInterval { val newMax = if (n1.max <= n2.max) n1.max else n2.max (Some(EstimationUtils.fromDouble(newMin, dt)), Some(EstimationUtils.fromDouble(newMax, dt))) + case _ => + throw new UnsupportedOperationException(s"Not supported pair: $r1, $r2 at intersect()") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4d9a9925fe3ff..cd28c733f3613 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -99,16 +101,19 @@ case class ClusteredDistribution( * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the * number of partitions, this distribution strictly requires which partition the tuple should be in. */ -case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { +case class HashClusteredDistribution( + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { require( expressions != Nil, - "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "The expressions for hash of a HashClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - override def requiredNumPartitions: Option[Int] = None - override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") HashPartitioning(expressions, numPartitions) } } @@ -163,11 +168,22 @@ trait Partitioning { * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). * + * A [[Partitioning]] can never satisfy a [[Distribution]] if its `numPartitions` does't match + * [[Distribution.requiredNumPartitions]]. 
+ */ + final def satisfies(required: Distribution): Boolean = { + required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required) + } + + /** + * The actual method that defines whether this [[Partitioning]] can satisfy the given + * [[Distribution]], after the `numPartitions` check. + * * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if - * the [[Partitioning]] only have one partition. Implementations can overwrite this method with - * special logic. + * the [[Partitioning]] only have one partition. Implementations can also overwrite this method + * with special logic. */ - def satisfies(required: Distribution): Boolean = required match { + protected def satisfies0(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case AllTuples => numPartitions == 1 case _ => false @@ -186,13 +202,24 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } } +/** + * Represents a partitioning where rows are only serialized/deserialized locally. The number + * of partitions are not changed and also the distribution of rows. This is mainly used to + * obtain some statistics of map tasks such as number of outputs. + */ +case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning { + val numPartitions = childRDD.getNumPartitions + + // We will perform this partitioning no matter what the data distribution is. + override def satisfies0(required: Distribution): Boolean = false +} + /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. 
All rows where `expressions` evaluate to the same values are guaranteed to be @@ -205,16 +232,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case h: HashClusteredDistribution => expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -246,15 +272,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -295,7 +320,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Returns true if any `partitioning` of this collection satisfies the given * [[Distribution]]. 
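To make the satisfies/satisfies0 split in this file easier to follow, here is a toy model under invented names (not Spark's classes): the final satisfies() applies the requiredNumPartitions check exactly once, and implementations only override satisfies0().

    trait Dist { def requiredNumPartitions: Option[Int] }
    case class Clustered(
        keys: Seq[String],
        requiredNumPartitions: Option[Int] = None) extends Dist

    abstract class Part(val numPartitions: Int) {
      final def satisfies(required: Dist): Boolean =
        required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required)
      protected def satisfies0(required: Dist): Boolean
    }

    case class HashPart(keys: Seq[String], n: Int) extends Part(n) {
      // numPartitions has already been checked by satisfies()
      override protected def satisfies0(required: Dist): Boolean = required match {
        case Clustered(ks, _) => keys.forall(ks.contains)
        case _                => false
      }
    }

    // HashPart(Seq("a"), 10).satisfies(Clustered(Seq("a", "b"), Some(5)))  // false: 5 != 10
    // HashPart(Seq("a"), 10).satisfies(Clustered(Seq("a", "b")))           // true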
*/ - override def satisfies(required: Distribution): Boolean = + override def satisfies0(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) override def toString: String = { @@ -310,7 +335,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { override val numPartitions: Int = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case BroadcastDistribution(m) if m == mode => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9c7d47f99ee10..becfa8d982213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer + def mapTreeNode(node: TreeNode[_]): TreeNode[_] = { + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case nonChild: AnyRef => nonChild + case null => null + } val newArgs = mapProductIterator { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. 
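The TreeNode hunk that continues below adds explicit .force calls whenever children are held in a Stream. A quick illustration of the laziness being guarded against: with Scala 2.11/2.12 Streams, map transforms only the head eagerly, so side effects such as the `changed` flag would otherwise be skipped for later children.

    var count = 0
    val s = Stream(1, 2, 3).map { x => count += 1; x + 1 }
    count      // 1: only the head has been mapped so far
    s.toList   // forces the remaining elements
    count      // 3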
- case s: Seq[_] => s.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - } - case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - }.view.force // `mapValues` is lazy and we need to force it to materialize - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } + case s: Stream[_] => + // Stream is lazy so we need to force materialization + s.map(mapChild).force + case s: Seq[_] => + s.map(mapChild) + case m: Map[_, _] => + // `mapValues` is lazy and we need to force it to materialize + m.mapValues(mapChild).view.force + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) case nonChild: AnyRef => nonChild case null => null } @@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { var changed = false + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) @@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = if (containsChild(arg1)) { - f(arg1.asInstanceOf[BaseType]) - } else { - arg1.asInstanceOf[BaseType] - } - - val newChild2 = if (containsChild(arg2)) { - f(arg2.asInstanceOf[BaseType]) - } else { - arg2.asInstanceOf[BaseType] - } - - if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { - changed = true - (newChild1, newChild2) - } else { - tuple - } - case other => other - } + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Traversable[_] => args.map(mapChild) case nonChild: AnyRef => nonChild case null => null } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index fa69b8af62c85..02813d3939796 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -96,9 +96,9 @@ object DateTimeUtils { } } - def getThreadLocalDateFormat(): DateFormat = { + def getThreadLocalDateFormat(timeZone: TimeZone): DateFormat = { val sdf = threadLocalDateFormat.get() - sdf.setTimeZone(defaultTimeZone()) + sdf.setTimeZone(timeZone) sdf } @@ -144,7 +144,11 @@ object DateTimeUtils { } def dateToString(days: SQLDate): String = - getThreadLocalDateFormat.format(toJavaDate(days)) + getThreadLocalDateFormat(defaultTimeZone()).format(toJavaDate(days)) + + def dateToString(days: SQLDate, timeZone: TimeZone): String = { + getThreadLocalDateFormat(timeZone).format(toJavaDate(days)) + } // Converts Timestamp to string according to Hive TimestampWritable convention. def timestampToString(us: SQLTimestamp): String = { @@ -296,10 +300,28 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { - stringToTimestamp(s, defaultTimeZone()) + stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false) } def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { + stringToTimestamp(s, timeZone, rejectTzInString = false) + } + + /** + * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone. + * Returns None if the input string is not a valid timestamp format. + * + * @param s the input timestamp string. + * @param timeZone the timezone of the timestamp string, will be ignored if the timestamp string + * already contains timezone information and `forceTimezone` is false. + * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the + * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`, + * return None. + */ + def stringToTimestamp( + s: UTF8String, + timeZone: TimeZone, + rejectTzInString: Boolean): Option[SQLTimestamp] = { if (s == null) { return None } @@ -417,6 +439,8 @@ object DateTimeUtils { return None } + if (tz.isDefined && rejectTzInString) return None + val c = if (tz.isEmpty) { Calendar.getInstance(timeZone) } else { @@ -865,29 +889,19 @@ object DateTimeUtils { /** * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. + * microseconds since 1.1.1970. If time1 is later than time2, the result is positive. * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). + * If time1 and time2 are on the same day of month, or both are the last day of month, + * returns, time of day will be ignored. * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. + * Otherwise, the difference is calculated based on 31 days per month. + * The result is rounded to 8 decimal places if `roundOff` is set to true. */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { - monthsBetween(time1, time2, defaultTimeZone()) - } - - /** - * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. 
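For the monthsBetween rework documented above (the explicit roundOff flag and time zone in the signature that follows), a hedged worked call: both timestamps fall on the same day of month, so the time of day is ignored and the result is the whole-month difference.

    import java.sql.Timestamp
    import java.util.TimeZone
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    val utc = TimeZone.getTimeZone("UTC")
    val t1  = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2018-03-15 12:00:00"))
    val t2  = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2018-01-15 12:00:00"))
    DateTimeUtils.monthsBetween(t1, t2, roundOff = true, utc)   // 2.0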
- * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). - * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. - */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = { + def monthsBetween( + time1: SQLTimestamp, + time2: SQLTimestamp, + roundOff: Boolean, + timeZone: TimeZone): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L val date1 = millisToDays(millis1, timeZone) @@ -898,16 +912,25 @@ object DateTimeUtils { val months1 = year1 * 12 + monthInYear1 val months2 = year2 * 12 + monthInYear2 + val monthDiff = (months1 - months2).toDouble + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { - return (months1 - months2).toDouble + return monthDiff + } + // using milliseconds can cause precision loss with more than 8 digits + // we follow Hive's implementation which uses seconds + val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L + val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L + val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 + // 2678400D is the number of seconds in 31 days + // every month is considered to be 31 days long in this function + val diff = monthDiff + secondsDiff / 2678400D + if (roundOff) { + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } else { + diff } - // milliseconds is enough for 8 digits precision on the right side - val timeInDay1 = millis1 - daysToMillis(date1, timeZone) - val timeInDay2 = millis2 - daysToMillis(date2, timeZone) - val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY - val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 - // rounding to 8 digits - math.round(diff * 1e8) / 1e8 } // Thursday = 0 since 1970/Jan/01 => Thursday diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 9e39ed9c3a778..83ad08d8e1758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -122,7 +122,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } - case _ => if (o1 != o2) { + case _ => if (!o1.equals(o2)) { return false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index b013add9c9778..3190e511e2cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -40,12 +40,14 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * See the G-K article for more details. 
* @param count the count of all the elements *inserted in the sampled buffer* * (excluding the head buffer) + * @param compressed whether the statistics have been compressed */ class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, val sampled: Array[Stats] = Array.empty, - val count: Long = 0L) extends Serializable { + val count: Long = 0L, + var compressed: Boolean = false) extends Serializable { // a buffer of latest samples seen so far private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty @@ -60,6 +62,7 @@ class QuantileSummaries( */ def insert(x: Double): QuantileSummaries = { headSampled += x + compressed = false if (headSampled.size >= defaultHeadSize) { val result = this.withHeadBufferInserted if (result.sampled.length >= compressThreshold) { @@ -135,11 +138,11 @@ class QuantileSummaries( assert(inserted.count == count + headSampled.size) val compressed = compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) - new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true) } private def shallowCopy: QuantileSummaries = { - new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new QuantileSummaries(compressThreshold, relativeError, sampled, count, compressed) } /** @@ -163,7 +166,7 @@ class QuantileSummaries( val res = (sampled ++ other.sampled).sortBy(_.value) val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) new QuantileSummaries( - other.compressThreshold, other.relativeError, comp, other.count + count) + other.compressThreshold, other.relativeError, comp, other.count + count, true) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala new file mode 100644 index 0000000000000..ae05128f94777 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.commons.math3.random.MersenneTwister + +/** + * This class is used to generate a random indices of given length. + * + * This implementation uses the "inside-out" version of Fisher-Yates algorithm. 
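Returning briefly to the QuantileSummaries hunk above: a hedged usage sketch that only touches the members visible there (the constructor, insert, compress and merge). This is a Catalyst-internal class, so treat the snippet as illustrative rather than a supported API:

    import org.apache.spark.sql.catalyst.util.QuantileSummaries

    var left  = new QuantileSummaries(compressThreshold = 10000, relativeError = 0.01)
    var right = new QuantileSummaries(compressThreshold = 10000, relativeError = 0.01)
    (1 to 5000).foreach(i => left = left.insert(i.toDouble))
    (5001 to 10000).foreach(i => right = right.insert(i.toDouble))

    // compress() drains the head buffer and, with this change, marks the result as
    // compressed = true; merge() likewise returns an already-compressed summary.
    val merged = left.compress().merge(right.compress())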
+ * Reference: + * https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_%22inside-out%22_algorithm + */ +case class RandomIndicesGenerator(randomSeed: Long) { + private val random = new MersenneTwister(randomSeed) + + def getNextIndices(length: Int): Array[Int] = { + val indices = new Array[Int](length) + var i = 0 + while (i < length) { + val j = random.nextInt(i + 1) + if (j != i) { + indices(i) = indices(j) + } + indices(j) = i + i += 1 + } + indices + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1dcda49a3af6a..76218b459ef0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -17,19 +17,19 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.types._ /** - * Helper functions to check for valid data types. + * Functions to help with checking for valid data types and value comparison of various types. */ object TypeUtils { def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = { if (dt.isInstanceOf[NumericType] || dt == NullType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt") + TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not ${dt.catalogString}") } } @@ -37,23 +37,18 @@ object TypeUtils { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt") + TypeCheckResult.TypeCheckFailure( + s"$caller does not support ordering on type ${dt.catalogString}") } } def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { - if (types.size <= 1) { + if (TypeCoercion.haveSameType(types)) { TypeCheckResult.TypeCheckSuccess } else { - val firstType = types.head - types.foreach { t => - if (!t.sameType(firstType)) { - return TypeCheckResult.TypeCheckFailure( - s"input to $caller should all be the same type, but it's " + - types.map(_.simpleString).mkString("[", ", ", "]")) - } - } - TypeCheckResult.TypeCheckSuccess + return TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's " + + types.map(_.catalogString).mkString("[", ", ", "]")) } } @@ -78,4 +73,15 @@ object TypeUtils { } x.length - y.length } + + /** + * Returns true if the equals method of the elements of the data type is implemented properly. + * This also means that they can be safely used in collections relying on the equals method, + * as sets or maps. 
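The typeWithProperEquals helper whose scaladoc ends here excludes BinaryType (see the body just below) because JVM arrays compare by reference, which makes binary values unsafe as keys in equals-based collections:

    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)
    a == b              // false: arrays use reference equality
    Set(a).contains(b)  // false for the same reason
    a.sameElements(b)   // true: element-wise comparison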
+ */ + def typeWithProperEquals(dataType: DataType): Boolean = dataType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 4005087dad05a..0978e92dd4f72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -155,6 +155,18 @@ package object util { def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql + + def escapeSingleQuotedString(str: String): String = { + val builder = StringBuilder.newBuilder + + str.foreach { + case '\'' => builder ++= s"\\\'" + case ch => builder += ch + } + + builder.toString() + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. 
+ */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3729bd5293eca..738d8fee891d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -20,18 +20,21 @@ package org.apache.spark.sql.internal import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference +import java.util.zip.Deflater import scala.collection.JavaConverters._ import scala.collection.immutable import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.tukaani.xz.LZMA2Options -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.Utils @@ -79,6 +82,19 @@ object SQLConf { /** See [[get]] for more information. */ def getFallbackConf: SQLConf = fallbackConf.get() + private lazy val existingConf = new ThreadLocal[SQLConf] { + override def initialValue: SQLConf = null + } + + def withExistingConf[T](conf: SQLConf)(f: => T): T = { + existingConf.set(conf) + try { + f + } finally { + existingConf.remove() + } + } + /** * Defines a getter that returns the SQLConf within scope. * See [[get]] for more information. @@ -95,7 +111,9 @@ object SQLConf { /** * Returns the active config object within the current scope. If there is an active SparkSession, - * the proper SQLConf associated with the thread's session is used. + * the proper SQLConf associated with the thread's active session is used. If it's called from + * tasks in the executor side, a SQLConf will be created from job local properties, which are set + * and propagated from the driver side. 
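A hedged sketch of the executor-side behavior described above: inside a task, SQLConf.get now returns a ReadOnlySQLConf built from the task's local properties, and any attempt to mutate it fails. The little job below is illustrative only (local mode, arbitrary config key):

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    object ReadOnlyConfDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
        spark.range(0, 100).rdd.foreachPartition { _ =>
          // Running inside a task, so SQLConf.get hands back a ReadOnlySQLConf backed
          // by the propagated local properties rather than the driver's mutable conf.
          assert(TaskContext.get() != null)
          val conf = SQLConf.get
          val shufflePartitions = conf.getConfString("spark.sql.shuffle.partitions", "200")
          // Any mutation, e.g. conf.clear(), would throw UnsupportedOperationException here.
        }
        spark.stop()
      }
    }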
* * The way this works is a little bit convoluted, due to the fact that config was added initially * only for physical plans (and as a result not in sql/catalyst module). @@ -107,7 +125,38 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + val isSchedulerEventLoopThread = SparkContext.getActive + .map(_.dagScheduler.eventProcessLoop.eventThread) + .exists(_.getId == Thread.currentThread().getId) + if (isSchedulerEventLoopThread) { + // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` + // will return `fallbackConf` which is unexpected. Here we require the caller to get the + // conf within `withExistingConf`, otherwise fail the query. + val conf = existingConf.get() + if (conf != null) { + conf + } else if (Utils.isTesting) { + throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } else { + confGetter.get()() + } + } else { + confGetter.get()() + } + } + } + + val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") + .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + + "specified by their rule names and separated by comma. It is not guaranteed that all the " + + "rules in this configuration will eventually be excluded, as some rules are necessary " + + "for correctness. The optimizer will log the rules that have indeed been excluded.") + .stringConf + .createOptional val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -186,6 +235,13 @@ object SQLConf { .intConf .createWithDefault(4) + val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit") + .internal() + .doc("During global limit, try to evenly distribute limited rows across data " + + "partitions. If disabled, scanning data partitions sequentially until reaching limit number.") + .booleanConf + .createWithDefault(true) + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") .internal() @@ -342,10 +398,10 @@ object SQLConf { "`parquet.compression` is specified in the table-specific options/properties, the " + "precedence would be `compression`, `parquet.compression`, " + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + - "snappy, gzip, lzo.") + "snappy, gzip, lzo, brotli, lz4, zstd.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -360,6 +416,43 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.timestamp") + .doc("If true, enables Parquet filter push-down optimization for Timestamp. 
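A brief, hedged usage note for the new user-facing knobs added in this stretch of SQLConf; the keys come from the surrounding hunks, while the codec value and the rule name are merely illustrative (whether a codec is actually usable also depends on the Parquet/Hadoop jars on the classpath):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("conf-demo")
      // zstd is one of the values newly accepted by spark.sql.parquet.compression.codec
      .config("spark.sql.parquet.compression.codec", "zstd")
      .getOrCreate()

    // Exclude an optimizer rule by its fully qualified name; this is best effort,
    // since rules needed for correctness are kept even if listed here.
    spark.conf.set("spark.sql.optimizer.excludedRules",
      "org.apache.spark.sql.catalyst.optimizer.PushDownPredicate")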
" + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is " + + "enabled and Timestamp stored as TIMESTAMP_MICROS or TIMESTAMP_MILLIS type.") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.decimal") + .doc("If true, enables Parquet filter push-down optimization for Decimal. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.string.startsWith") + .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD = + buildConf("spark.sql.parquet.pushdown.inFilterThreshold") + .doc("The maximum number of values to filter push-down optimization for IN predicate. " + + "Large threshold won't necessarily provide much better performance. " + + "The experiment argued that 300 is the limit threshold. " + + "By setting this value to 0 this feature can be disabled. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .intConf + .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") + .createWithDefault(10) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -377,7 +470,7 @@ object SQLConf { .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" + - "will never be created, irrespective of the value of parquet.enable.summary-metadata") + "will never be created, irrespective of the value of parquet.summary.metadata.level") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") @@ -581,6 +674,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") + .doc("The maximum number of buckets allowed. Defaults to 100000") + .intConf + .checkValue(_ > 0, "the value of spark.sql.sources.bucketing.maxBuckets must be larger than 0") + .createWithDefault(100000) + val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled") .doc("When false, we will throw an error if a query contains a cartesian product without " + "explicit CROSS JOIN syntax.") @@ -686,6 +785,17 @@ object SQLConf { .intConf .createWithDefault(100) + val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode") + .doc("This config determines the fallback behavior of several codegen generators " + + "during tests. `FALLBACK` means trying codegen first and then fallbacking to " + + "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " + + "`NO_CODEGEN` skips codegen and goes interpreted path always. 
Note that " + + "this config works only for tests.") + .internal() + .stringConf + .checkValues(CodegenObjectFactoryMode.values.map(_.toString)) + .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString) + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + @@ -777,6 +887,14 @@ object SQLConf { .intConf .createWithDefault(10) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") + .internal() + .doc("State format version used by flatMapGroupsWithState operation in a streaming query") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation") .doc("The default location for storing checkpoint data for streaming queries.") .stringConf @@ -788,6 +906,25 @@ object SQLConf { .intConf .createWithDefault(100) + val MAX_BATCHES_TO_RETAIN_IN_MEMORY = buildConf("spark.sql.streaming.maxBatchesToRetainInMemory") + .internal() + .doc("The maximum number of batches which will be retained in memory to avoid " + + "loading from files. The value adjusts a trade-off between memory usage vs cache miss: " + + "'2' covers both success and direct failure cases, '1' covers only success case, " + + "and '0' covers extreme case - disable cache to maximize memory size of executors.") + .intConf + .createWithDefault(2) + + val STREAMING_AGGREGATION_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.aggregation.stateFormatVersion") + .internal() + .doc("State format version used by streaming aggregation operations in a streaming query. " + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -838,6 +975,21 @@ object SQLConf { .stringConf .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") + val STREAMING_MULTIPLE_WATERMARK_POLICY = + buildConf("spark.sql.streaming.multipleWatermarkPolicy") + .doc("Policy to calculate the global watermark value when there are multiple watermark " + + "operators in a streaming query. The default value is 'min' which chooses " + + "the minimum watermark reported across multiple operators. Other alternative value is" + + "'max' which chooses the maximum across multiple operators." + + "Note: This configuration cannot be changed between query restarts from the same " + + "checkpoint location.") + .stringConf + .checkValue( + str => Set("min", "max").contains(str.toLowerCase), + "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. 
" + + "Valid values are 'min' and 'max'") + .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") .internal() @@ -919,6 +1071,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10000L) + val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = + buildConf("spark.sql.streaming.noDataMicroBatchesEnabled") + .doc( + "Whether streaming micro-batch engine will execute batches without data " + + "for eager state management for stateful streaming queries.") + .booleanConf + .createWithDefault(true) + val STREAMING_METRICS_ENABLED = buildConf("spark.sql.streaming.metricsEnabled") .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") @@ -1124,6 +1284,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION = + buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition") + .internal() + .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + + "Pandas DataFrame based on position, regardless of column label type. When false, " + + "columns will be looked up by name if labeled with a string and fallback to use " + + "position if not. This configuration will be deprecated in future releases.") + .booleanConf + .createWithDefault(false) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -1147,8 +1317,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SQL_OPTIONS_REDACTION_PATTERN = + buildConf("spark.sql.redaction.options.regex") + .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + + "information. The values of options whose names that match this regex will be redacted " + + "in the explain output. This redaction is applied on top of the global redaction " + + s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.") + .regexConf + .createWithDefault("(?i)url".r) + val SQL_STRING_REDACTION_PATTERN = - ConfigBuilder("spark.sql.redaction.string.regex") + buildConf("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + "information. When this regex matches a string part, that string part is replaced by a " + "dummy value. This is currently used to redact the output of SQL explain commands. " + @@ -1208,6 +1387,13 @@ object SQLConf { .stringConf .createWithDefault("") + val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString") + .internal() + .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " + + "contains a timezone part, e.g. `2000-10-10 00:00:00+00:00`.") + .booleanConf + .createWithDefault(true) + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1220,7 +1406,11 @@ object SQLConf { "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + "those partitions that have data written into it at runtime. By default we use static " + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + - "affect Hive serde tables, as they are always overwritten with dynamic mode.") + "affect Hive serde tables, as they are always overwritten with dynamic mode. 
This can " + + "also be set as an output option for a data source using key partitionOverwriteMode " + + "(which takes precedence over this setting), e.g. " + + "dataframe.write.option(\"partitionOverwriteMode\", \"dynamic\").save(path)." + ) .stringConf .transform(_.toUpperCase(Locale.ROOT)) .checkValues(PartitionOverwriteMode.values.map(_.toString)) @@ -1235,8 +1425,27 @@ object SQLConf { "issues. Turn on this config to insert a local sort before actually doing repartition " + "to generate consistent repartition results. The performance of repartition() may go " + "down since we insert extra local sort before it.") + .booleanConf + .createWithDefault(true) + + val NESTED_SCHEMA_PRUNING_ENABLED = + buildConf("spark.sql.nestedSchemaPruning.enabled") + .internal() + .doc("Prune nested fields from a logical relation's output which are unnecessary in " + + "satisfying a query. This optimization allows columnar file format readers to avoid " + + "reading unnecessary nested column data. Currently Parquet is the only data source that " + + "implements this optimization.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) + + val TOP_K_SORT_FALLBACK_THRESHOLD = + buildConf("spark.sql.execution.topKSortFallbackThreshold") + .internal() + .doc("In SQL queries with a SORT followed by a LIMIT like " + + "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" + + " in memory, otherwise do a global sort which spills to disk if necessary.") + .intConf + .createWithDefault(Int.MaxValue) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -1245,6 +1454,94 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) + + val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") + .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " + + "The size function returns null for null input if the flag is disabled.") + .booleanConf + .createWithDefault(true) + + val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled") + .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + + "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " + + "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " + + "the returned outputs are formatted like dataframe.show().") + .booleanConf + .createWithDefault(false) + + val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows") + .doc("The max number of rows that are returned by eager evaluation. This only takes " + + "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " + + "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " + + "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).") + .intConf + .createWithDefault(20) + + val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate") + .doc("The max number of characters for each cell that is returned by eager evaluation. 
" + + "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.") + .intConf + .createWithDefault(20) + + val FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT = + buildConf("spark.sql.codegen.aggregate.fastHashMap.capacityBit") + .internal() + .doc("Capacity for the max number of rows to be held in memory " + + "by the fast hash aggregate product operator. The bit is not for actual value, " + + "but the actual numBuckets is determined by loadFactor " + + "(e.g: default bit value 16 , the actual numBuckets is ((1 << 16) / 0.5).") + .intConf + .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].") + .createWithDefault(16) + + val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") + .doc("Compression codec used in writing of AVRO files. Supported codecs: " + + "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.") + .stringConf + .checkValues(Set("uncompressed", "deflate", "snappy", "bzip2", "xz")) + .createWithDefault("snappy") + + val AVRO_DEFLATE_LEVEL = buildConf("spark.sql.avro.deflate.level") + .doc("Compression level for the deflate codec used in writing of AVRO files. " + + "Valid value must be in the range of from 1 to 9 inclusive or -1. " + + "The default value is -1 which corresponds to 6 level in the current implementation.") + .intConf + .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION) + .createWithDefault(Deflater.DEFAULT_COMPRESSION) + + val LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED = + buildConf("spark.sql.legacy.replaceDatabricksSparkAvro.enabled") + .doc("If it is set to true, the data source provider com.databricks.spark.avro is mapped " + + "to the built-in but external Avro data source module for backward compatibility.") + .booleanConf + .createWithDefault(true) + + val LEGACY_SETOPS_PRECEDENCE_ENABLED = + buildConf("spark.sql.legacy.setopsPrecedence.enabled") + .internal() + .doc("When set to true and the order of evaluation is not specified by parentheses, the " + + "set operations are performed from left to right as they appear in the query. When set " + + "to false and order of evaluation is not specified by parentheses, INTERSECT operations " + + "are performed before any UNION, EXCEPT and MINUS operations.") + .booleanConf + .createWithDefault(false) + + val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = + buildConf("spark.sql.parallelFileListingInStatsComputation.enabled") + .internal() + .doc("When true, SQL commands use parallel file listing, " + + "as opposed to single thread listing." + + "This usually speeds up commands that need to list many directories.") + .booleanConf + .createWithDefault(true) } /** @@ -1259,20 +1556,16 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. 
*/ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ + def optimizerExcludedRules: Option[String] = getConf(OPTIMIZER_EXCLUDED_RULES) + def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS) def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) @@ -1306,6 +1599,9 @@ class SQLConf extends Serializable with Logging { def streamingNoDataProgressEventInterval: Long = getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + def streamingNoDataMicroBatchesEnabled: Boolean = + getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) @@ -1350,10 +1646,22 @@ class SQLConf extends Serializable with Logging { def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN) + def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY) + def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + + def parquetFilterPushDownDecimal: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED) + + def parquetFilterPushDownStringStartWith: Boolean = + getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + + def parquetFilterPushDownInFilterThreshold: Int = + getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) @@ -1382,6 +1690,8 @@ class SQLConf extends Serializable with Logging { def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) + def codegenComments: Boolean = getConf(StaticSQLConf.CODEGEN_COMMENTS) + def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) @@ -1392,6 +1702,8 @@ class SQLConf extends Serializable with Logging { def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) + def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) @@ -1402,10 +1714,14 @@ class SQLConf extends Serializable with Logging { def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) - def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + + def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
@@ -1425,6 +1741,8 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT) + def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) @@ -1491,6 +1809,8 @@ class SQLConf extends Serializable with Logging { def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def bucketingMaxBuckets: Int = getConf(SQLConf.BUCKETING_MAX_BUCKETS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) @@ -1579,6 +1899,9 @@ class SQLConf extends Serializable with Logging { def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) + def pandasGroupedMapAssignColumnssByPosition: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) @@ -1603,6 +1926,30 @@ class SQLConf extends Serializable with Logging { def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED) + + def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + + def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) + + def isReplEagerEvalEnabled: Boolean = getConf(SQLConf.REPL_EAGER_EVAL_ENABLED) + + def replEagerEvalMaxNumRows: Int = getConf(SQLConf.REPL_EAGER_EVAL_MAX_NUM_ROWS) + + def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + + def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC) + + def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) + + def replaceDatabricksSparkAvroEnabled: Boolean = + getConf(SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED) + + def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) + + def parallelFileListingInStatsComputation: Boolean = + getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ @@ -1709,6 +2056,17 @@ class SQLConf extends Serializable with Logging { }.toSeq } + /** + * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN. + */ + def redactOptions(options: Map[String, String]): Map[String, String] = { + val regexes = Seq( + getConf(SQL_OPTIONS_REDACTION_PATTERN), + SECRET_REDACTION_PATTERN.readFrom(reader)) + + regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap + } + /** * Return whether a given key is set in this [[SQLConf]]. 
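A behavioural sketch (not the production call site, which is EXPLAIN output) of what the redactOptions method added in this patch is expected to do with the default "(?i)url" pattern; the option map below is invented for illustration.

import org.apache.spark.sql.internal.SQLConf

val conf = new SQLConf
val options = Map(
  "url" -> "jdbc:postgresql://db:5432/sales?user=x&password=y",
  "dbtable" -> "orders")
// "url" matches the default pattern, so its value comes back as a fixed dummy string;
// "dbtable" is returned unchanged. The global spark.redaction.regex is applied as well.
val redacted = conf.redactOptions(options)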
*/ @@ -1716,7 +2074,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } @@ -1748,4 +2106,8 @@ class SQLConf extends Serializable with Logging { } cloned } + + def isModifiable(key: String): Boolean = { + sqlConfEntries.containsKey(key) && !staticConfKeys.contains(key) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index fe0ad39c29025..d9c354b165e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -66,6 +66,22 @@ object StaticSQLConf { .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") .createWithDefault(1000) + val CODEGEN_CACHE_MAX_ENTRIES = buildStaticConf("spark.sql.codegen.cache.maxEntries") + .internal() + .doc("When nonzero, enable caching of generated classes for operators and expressions. " + + "All jobs share the cache that can use up to the specified number for generated classes.") + .intConf + .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative") + .createWithDefault(100) + + val CODEGEN_COMMENTS = buildStaticConf("spark.sql.codegen.comments") + .internal() + .doc("When true, put comment in the generated code. Since computing huge comments " + + "can be extremely expensive in certain cases, such as deeply-nested expressions which " + + "operate over inputs with wide schemas, default is false.") + .booleanConf + .createWithDefault(false) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildStaticConf("spark.sql.debug") @@ -96,6 +112,14 @@ object StaticSQLConf { .toSequence .createOptional + val STREAMING_QUERY_LISTENERS = buildStaticConf("spark.sql.streaming.streamingQueryListeners") + .doc("List of class names implementing StreamingQueryListener that will be automatically " + + "added to newly created sessions. 
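Hedged sketch of a listener wired in through the new spark.sql.streaming.streamingQueryListeners static conf documented above; the class and package names are assumptions for illustration.

import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

class AuditListener extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"query ${event.id} started")
  override def onQueryProgress(event: QueryProgressEvent): Unit = ()
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = ()
}

// Being a static conf, the listener list is supplied at session/JVM start, e.g.:
//   --conf spark.sql.streaming.streamingQueryListeners=com.example.AuditListener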
The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional + val UI_RETAINED_EXECUTIONS = buildStaticConf("spark.sql.ui.retainedExecutions") .doc("Number of executions to retain in the Spark UI.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 3041f44b116ea..c43cc748655e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -145,7 +145,7 @@ abstract class NumericType extends AtomicType { } -private[sql] object NumericType extends AbstractDataType { +private[spark] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -155,11 +155,12 @@ private[sql] object NumericType extends AbstractDataType { */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - override private[sql] def defaultConcreteType: DataType = DoubleType + override private[spark] def defaultConcreteType: DataType = DoubleType - override private[sql] def simpleString: String = "numeric" + override private[spark] def simpleString: String = "numeric" - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] + override private[spark] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 38c40482fa4d9..58c75b5dc7a35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -42,7 +42,7 @@ object ArrayType extends AbstractDataType { other.isInstanceOf[ArrayType] } - override private[sql] def simpleString: String = "array" + override private[spark] def simpleString: String = "array" } /** @@ -103,7 +103,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => - throw new IllegalArgumentException(s"Type $other does not support ordered operations") + throw new IllegalArgumentException( + s"Type ${other.catalogString} does not support ordered operations") } def compare(x: ArrayData, y: ArrayData): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..e53628d11ccf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -19,13 +19,17 @@ package org.apache.spark.sql.types import java.util.Locale +import scala.util.control.NonFatal + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -110,6 +114,16 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r + + def fromDDL(ddl: String): DataType = { + try { + CatalystSqlParser.parseDataType(ddl) + } catch { + case NonFatal(_) => CatalystSqlParser.parseTableSchema(ddl) + } + } + def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { @@ -120,7 +134,6 @@ object DataType { /** Given the string representation of a type, return its DataType */ private def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) @@ -325,4 +338,124 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + private val SparkGeneratedName = """col\d+""".r + private def isSparkGeneratedName(name: String): Boolean = name match { + case SparkGeneratedName(_*) => true + case _ => false + } + + /** + * Returns true if the write data type can be read using the read data type. + * + * The write type is compatible with the read type if: + * - Both types are arrays, the array element types are compatible, and element nullability is + * compatible (read allows nulls or write does not contain nulls). + * - Both types are maps and the map key and value types are compatible, and value nullability + * is compatible (read allows nulls or write does not contain nulls). + * - Both types are structs and each field in the read struct is present in the write struct and + * compatible (including nullability), or is nullable if the write struct does not contain the + * field. Write-side structs are not compatible if they contain fields that are not present in + * the read-side struct. + * - Both types are atomic and the write type can be safely cast to the read type. + * + * Extra fields in write-side structs are not allowed to avoid accidentally writing data that + * the read schema will not read, and to ensure map key equality is not changed when data is read. + * + * @param write a write-side data type to validate against the read type + * @param read a read-side data type + * @return true if data written with the write type can be read using the read type + */ + def canWrite( + write: DataType, + read: DataType, + resolver: Resolver, + context: String, + addError: String => Unit = (_: String) => {}): Boolean = { + (write, read) match { + case (wArr: ArrayType, rArr: ArrayType) => + // run compatibility check first to produce all error messages + val typesCompatible = + canWrite(wArr.elementType, rArr.elementType, resolver, context + ".element", addError) + + if (wArr.containsNull && !rArr.containsNull) { + addError(s"Cannot write nullable elements to array of non-nulls: '$context'") + false + } else { + typesCompatible + } + + case (wMap: MapType, rMap: MapType) => + // map keys cannot include data fields not in the read schema without changing equality when + // read. map keys can be missing fields as long as they are nullable in the read schema. 
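Small usage sketch of the DataType.fromDDL helper added above: a single type string is handled by parseDataType, while a comma-separated field list fails that parse and falls back to parseTableSchema.

import org.apache.spark.sql.types._

val arrayType = DataType.fromDDL("array<int>")              // parsed directly as a data type
val tableSchema = DataType.fromDDL("eventId INT, s STRING") // falls back to parseTableSchema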
+ + // run compatibility check first to produce all error messages + val keyCompatible = + canWrite(wMap.keyType, rMap.keyType, resolver, context + ".key", addError) + val valueCompatible = + canWrite(wMap.valueType, rMap.valueType, resolver, context + ".value", addError) + val typesCompatible = keyCompatible && valueCompatible + + if (wMap.valueContainsNull && !rMap.valueContainsNull) { + addError(s"Cannot write nullable values to map of non-nulls: '$context'") + false + } else { + typesCompatible + } + + case (StructType(writeFields), StructType(readFields)) => + var fieldCompatible = true + readFields.zip(writeFields).foreach { + case (rField, wField) => + val namesMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name) + val fieldContext = s"$context.${rField.name}" + val typesCompatible = + canWrite(wField.dataType, rField.dataType, resolver, fieldContext, addError) + + if (!namesMatch) { + addError(s"Struct '$context' field name does not match (may be out of order): " + + s"expected '${rField.name}', found '${wField.name}'") + fieldCompatible = false + } else if (!rField.nullable && wField.nullable) { + addError(s"Cannot write nullable values to non-null field: '$fieldContext'") + fieldCompatible = false + } else if (!typesCompatible) { + // errors are added in the recursive call to canWrite above + fieldCompatible = false + } + } + + if (readFields.size > writeFields.size) { + val missingFieldsStr = readFields.takeRight(readFields.size - writeFields.size) + .map(f => s"'${f.name}'").mkString(", ") + if (missingFieldsStr.nonEmpty) { + addError(s"Struct '$context' missing fields: $missingFieldsStr") + fieldCompatible = false + } + + } else if (writeFields.size > readFields.size) { + val extraFieldsStr = writeFields.takeRight(writeFields.size - readFields.size) + .map(f => s"'${f.name}'").mkString(", ") + addError(s"Cannot write extra fields to struct '$context': $extraFieldsStr") + fieldCompatible = false + } + + fieldCompatible + + case (w: AtomicType, r: AtomicType) => + if (!Cast.canSafeCast(w, r)) { + addError(s"Cannot safely cast '$context': $w to $r") + false + } else { + true + } + + case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] => + true + + case (w, r) => + addError(s"Cannot write '$context': $w is incompatible with $r") + false + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 6da4f28b12962..9eed2eb202045 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -479,6 +479,25 @@ object Decimal { dec } + // Max precision of a decimal value stored in `numBytes` bytes + def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } + + // Returns the minimum number of bytes needed to store a decimal with a given `precision`. + lazy val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) + + private def computeMinBytesForPrecision(precision : Int) : Int = { + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } + // Evidence parameters for Decimal considered either as Fractional or Integral. 
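A hedged usage sketch of the DataType.canWrite check completed above, using the case-insensitive resolver from catalyst's analysis package; the schemas and context name are invented for illustration.

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.types._

val readSchema = new StructType()
  .add("id", LongType)
  .add("name", StringType, nullable = false)
val writeSchema = new StructType()
  .add("id", IntegerType)  // int -> long is a safe upcast, so this field is compatible
  .add("name", StringType) // nullable write into a non-null read field is rejected

val errors = ArrayBuffer.empty[String]
val ok = DataType.canWrite(writeSchema, readSchema, caseInsensitiveResolution, "t", errors += _)
// ok == false; errors holds a message about writing nullable values to non-null field 't.name'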
We provide two // parameters inheriting from a common trait since both traits define mkNumericOps. // See scala.math's Numeric.scala for examples for Scala's built-in types. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index ef3b67c0d48d0..15004e4b9667d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -48,7 +48,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException(s"DecimalType can only support precision up to 38") + throw new AnalysisException( + s"${DecimalType.simpleString} can only support precision up to ${DecimalType.MAX_PRECISION}") } // default constructor for Java @@ -120,6 +121,7 @@ object DecimalType extends AbstractDataType { val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types + private[sql] val BooleanDecimal = DecimalType(1, 0) private[sql] val ByteDecimal = DecimalType(3, 0) private[sql] val ShortDecimal = DecimalType(5, 0) private[sql] val IntDecimal = DecimalType(10, 0) @@ -161,13 +163,17 @@ object DecimalType extends AbstractDataType { * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. */ private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { - // Assumptions: + // Assumption: assert(precision >= scale) - assert(scale >= 0) if (precision <= MAX_PRECISION) { // Adjustment only needed when we exceed max precision DecimalType(precision, scale) + } else if (scale < 0) { + // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision + // loss since we would cause a loss of digits in the integer part. + // In this case, we are likely to meet an overflow. + DecimalType(MAX_PRECISION, scale) } else { // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. val intDigits = precision - scale diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala index e0bca937d1d84..4eb3226c5786e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala @@ -56,14 +56,18 @@ object HiveStringType { } /** - * Hive char type. + * Hive char type. Similar to other HiveStringType's, these datatypes should only used for + * parsing, and should NOT be used anywhere else. Any instance of these data types should be + * replaced by a [[StringType]] before analysis. */ case class CharType(length: Int) extends HiveStringType { override def simpleString: String = s"char($length)" } /** - * Hive varchar type. + * Hive varchar type. Similar to other HiveStringType's, these datatypes should only used for + * parsing, and should NOT be used anywhere else. Any instance of these data types should be + * replaced by a [[StringType]] before analysis. 
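Worked check of the byte-width helpers added to Decimal earlier in this patch; the values follow directly from the formulas in the new code.

import org.apache.spark.sql.types.Decimal

// 8 bytes hold a signed value up to 2^63 - 1, i.e. floor(log10(2^63 - 1)) = 18 full digits...
assert(Decimal.maxPrecisionForBytes(8) == 18)
// ...and 18-digit precision needs 8 bytes back: 2^(8*8 - 1) is the first power of two >= 10^18.
assert(Decimal.minBytesForPrecision(18) == 8)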
*/ case class VarcharType(length: Int) extends HiveStringType { override def simpleString: String = s"varchar($length)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 6691b81dcea8d..594e155268bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -42,9 +42,9 @@ case class MapType( private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- key: ${keyType.typeName}\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) builder.append(s"$prefix-- value: ${valueType.typeName} " + s"(valueContainsNull = $valueContainsNull)\n") - DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 352fb545f4b6b..7c15dc0de4b6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -215,6 +215,8 @@ object Metadata { x.## case x: Metadata => hash(x.map) + case null => + 0 case other => throw new RuntimeException(s"Do not support type ${other.getClass}.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 2d49fe076786a..203e85e1c99bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability @InterfaceStability.Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = - throw new UnsupportedOperationException("null literals can't be casted to ObjectType") + throw new UnsupportedOperationException( + s"null literals can't be casted to ${ObjectType.simpleString}") override private[sql] def acceptsType(other: DataType): Boolean = other match { case ObjectType(_) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 2c18fdcc497fe..902cae9150ede 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -21,6 +21,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} /** * A field inside a StructType. @@ -74,4 +75,16 @@ case class StructField( def getComment(): Option[String] = { if (metadata.contains("comment")) Option(metadata.getString("comment")) else None } + + /** + * Returns a string containing a schema in DDL format. For example, the following value: + * `StructField("eventId", IntegerType)` will be converted to `eventId` INT. 
+ */ + def toDDL: String = { + val comment = getComment() + .map(escapeSingleQuotedString) + .map(" COMMENT '" + _ + "'") + + s"${quoteIdentifier(name)} ${dataType.sql}${comment.getOrElse("")}" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 362676b252126..c5ca169c955dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.util.Utils /** @@ -360,6 +360,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"STRUCT<${fieldTypes.mkString(", ")}>" } + /** + * Returns a string containing a schema in DDL format. For example, the following value: + * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))` + * will be converted to `eventId` INT, `s` STRING. + * The returned DDL schema can be used in a table creation. + */ + def toDDL: String = fields.map(_.toDDL).mkString(",") + private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { @@ -426,7 +434,7 @@ object StructType extends AbstractDataType { private[sql] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") } } @@ -528,7 +536,8 @@ object StructType extends AbstractDataType { leftType case _ => - throw new SparkException(s"Failed to merge incompatible data types $left and $right") + throw new SparkException(s"Failed to merge incompatible data types ${left.catalogString}" + + s" and ${right.catalogString}") } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f3702ec92b425..89452ee05cff3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -94,4 +95,56 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) == doubleGenericArray) } + + test("converting a wrong value to the struct type") { + val structType = new StructType().add("f1", IntegerType) + val 
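Usage sketch of the toDDL helpers added above; the expected output follows the scaladoc in the patch (backtick-quoted names, DDL type names, fields joined with a comma), and the field comment is an illustrative assumption.

import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("eventId", IntegerType),
  StructField("s", StringType).withComment("raw string")))
println(schema.toDDL)
// `eventId` INT,`s` STRING COMMENT 'raw string'
val roundTripped = DataType.fromDDL(schema.toDDL)  // parses back into an equivalent StructType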
exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(structType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to struct")) + } + + test("converting a wrong value to the map type") { + val mapType = MapType(StringType, IntegerType, false) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(mapType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to a map type with key " + + "type (string) and value type (int)")) + } + + test("converting a wrong value to the array type") { + val arrayType = ArrayType(IntegerType, true) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(arrayType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to an array of int")) + } + + test("converting a wrong value to the decimal type") { + val decimalType = DecimalType(10, 0) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(decimalType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to decimal(10,0)")) + } + + test("converting a wrong value to the string type") { + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(StringType)(0.1) + } + assert(exception.getMessage.contains("The value (0.1) of the type " + + "(java.lang.Double) cannot be converted to the string type")) + } + + test("SPARK-24571: convert Char to String") { + val chr: Char = 'X' + val converter = CatalystTypeConverters.createToCatalystConverter(StringType) + val expected = UTF8String.fromString("X") + assert(converter(chr) === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index b47b8adfe5d55..39228102682b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -41,34 +41,127 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning (with nullSafe = true) is the output partitioning") { - // Cases which do not need an exchange between two data properties. 
+ test("UnspecifiedDistribution and AllTuples") { + // except `BroadcastPartitioning`, all other partitioning can satisfy UnspecifiedDistribution checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + UnknownPartitioning(-1), UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + RoundRobinPartitioning(10), + UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + SinglePartition, + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + UnspecifiedDistribution, + false) + + // except `BroadcastPartitioning`, all other partitioning can satisfy AllTuples if they have + // only one partition. + checkSatisfied( + UnknownPartitioning(1), + AllTuples, + true) + + checkSatisfied( + UnknownPartitioning(10), + AllTuples, + false) + + checkSatisfied( + RoundRobinPartitioning(1), + AllTuples, + true) + + checkSatisfied( + RoundRobinPartitioning(10), + AllTuples, + false) + + checkSatisfied( + SinglePartition, + AllTuples, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 1), + AllTuples, true) + checkSatisfied( + HashPartitioning(Seq('a), 10), + AllTuples, + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 1), + AllTuples, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + AllTuples, + false) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + AllTuples, + false) + } + + test("SinglePartition is the output partitioning") { + // SinglePartition can satisfy all the distributions except `BroadcastDistribution` checkSatisfied( SinglePartition, ClusteredDistribution(Seq('a, 'b, 'c)), true) + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), true) - // Cases which need an exchange between two data properties. + checkSatisfied( + SinglePartition, + BroadcastDistribution(IdentityBroadcastMode), + false) + } + + test("HashPartitioning is the output partitioning") { + // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of + // the required clustering expressions. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), ClusteredDistribution(Seq('b, 'c)), @@ -79,37 +172,43 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + // HashPartitioning can satisfy HashClusteredDistribution iff its hash expressions are exactly + // same with the required hash clustering expressions. 
checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), - AllTuples, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('c, 'b, 'a), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), + false) + + // HashPartitioning cannot satisfy OrderedDistribution checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 1), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) + false) // TODO: this can be relaxed. - // TODO: We should check functional dependencies - /* checkSatisfied( - ClusteredDistribution(Seq('b)), - ClusteredDistribution(Seq('b + 1)), - true) - */ + HashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) } test("RangePartitioning is the output partitioning") { - // Cases which do not need an exchange between two data properties. - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - UnspecifiedDistribution, - true) - + // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix + // of the required ordering, or the required ordering is a prefix of its ordering. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -125,6 +224,27 @@ class DistributionSuite extends SparkFunSuite { OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), true) + // TODO: We can have an optimization to first sort the dataset + // by a.asc and then sort b, and c in a partition. This optimization + // should tradeoff the benefit of a less number of Exchange operators + // and the parallelism. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('b.asc, 'a.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'd.desc)), + false) + + // RangePartitioning can satisfy ClusteredDistribution iff its ordering expressions are a subset + // of the required clustering expressions. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('a, 'b, 'c)), @@ -140,34 +260,47 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) - // Cases which need an exchange between two data properties. - // TODO: We can have an optimization to first sort the dataset - // by a.asc and then sort b, and c in a partition. This optimization - // should tradeoff the benefit of a less number of Exchange operators - // and the parallelism. 
checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + ClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('b.asc, 'a.asc)), + ClusteredDistribution(Seq('c, 'd)), false) + // RangePartitioning cannot satisfy HashClusteredDistribution checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + } + test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") { checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - AllTuples, + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 353b8344658f2..f9ee948b97e0a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -261,23 +261,6 @@ class ScalaReflectionSuite extends SparkFunSuite { } } - test("get parameter type from a function object") { - val primitiveFunc = (i: Int, j: Long) => "x" - val primitiveTypes = getParameterTypes(primitiveFunc) - assert(primitiveTypes.forall(_.isPrimitive)) - assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) - - val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" - val boxedTypes = getParameterTypes(boxedFunc) - assert(boxedTypes.forall(!_.isPrimitive)) - assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) - - val anyFunc = (i: Any, j: AnyRef) => "x" - val anyTypes = getParameterTypes(anyFunc) - assert(anyTypes.forall(!_.isPrimitive)) - assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) - } - test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) val serializer = serializerFor[List[Int]](BoundReference( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala new file mode 100644 index 0000000000000..68e76fc013c18 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf.NESTED_SCHEMA_PRUNING_ENABLED + +/** + * A PlanTest that ensures that all tests in this suite are run with nested schema pruning enabled. + * Remove this trait once the default value of SQLConf.NESTED_SCHEMA_PRUNING_ENABLED is set to true. + */ +private[sql] trait SchemaPruningTest extends PlanTest with BeforeAndAfterAll { + private var originalConfSchemaPruningEnabled = false + + override protected def beforeAll(): Unit = { + originalConfSchemaPruningEnabled = conf.nestedSchemaPruningEnabled + conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, true) + super.beforeAll() + } + + override protected def afterAll(): Unit = { + try { + super.afterAll() + } finally { + conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, originalConfSchemaPruningEnabled) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5d2f8e735e3d4..94778840d706b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -277,13 +277,13 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "intersect with unequal number of columns", - testRelation.intersect(testRelation2), + testRelation.intersect(testRelation2, isAll = false), "intersect" :: "number of columns" :: testRelation2.output.length.toString :: testRelation.output.length.toString :: Nil) errorTest( "except with unequal number of columns", - testRelation.except(testRelation2), + testRelation.except(testRelation2, isAll = false), "except" :: "number of columns" :: testRelation2.output.length.toString :: testRelation.output.length.toString :: Nil) @@ -299,22 +299,22 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "intersect with incompatible column types", - testRelation.intersect(nestedRelation), + testRelation.intersect(nestedRelation, isAll = false), "intersect" :: "the compatible column types" :: Nil) errorTest( "intersect with a incompatible column type and compatible column types", - testRelation3.intersect(testRelation4), + testRelation3.intersect(testRelation4, isAll = false), "intersect" :: "the compatible column types" :: "map" :: "decimal" :: Nil) errorTest( "except with incompatible column types", - testRelation.except(nestedRelation), + testRelation.except(nestedRelation, isAll = false), "except" :: "the compatible column types" :: Nil) errorTest( "except with a incompatible column type and compatible column types", - testRelation3.except(testRelation4), + testRelation3.except(testRelation4, isAll = false), "except" :: "the compatible column types" :: "map" :: "decimal" :: Nil) errorTest( @@ -334,14 +334,28 @@ class AnalysisErrorSuite extends AnalysisTest { "start time greater than slide duration in time window", testRelation.select( 
TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), - "The start time " :: " must be less than the slideDuration " :: Nil + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil ) errorTest( "start time equal to slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), - "The start time " :: " must be less than the slideDuration " :: Nil + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "SPARK-21590: absolute value of start time greater than slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 minute").as("window")), + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "SPARK-21590: absolute value of start time equal to slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 second").as("window")), + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil ) errorTest( @@ -372,13 +386,6 @@ class AnalysisErrorSuite extends AnalysisTest { "The slide duration" :: " must be greater than 0." :: Nil ) - errorTest( - "negative start time in time window", - testRelation.select( - TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")), - "The start time" :: "must be greater than or equal to 0." :: Nil - ) - errorTest( "generator nested in expressions", listRelation.select(Explode('list) + 1), @@ -392,6 +399,12 @@ class AnalysisErrorSuite extends AnalysisTest { "Generators are not supported outside the SELECT clause, but got: Sort" :: Nil ) + errorTest( + "an evaluated limit class must not be null", + testRelation.limit(Literal(null, IntegerType)), + "The evaluated limit expression must not be null, but got " :: Nil + ) + errorTest( "num_rows in limit clause must be equal to or greater than 0", listRelation.limit(-1), @@ -514,14 +527,14 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) + assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) } test("PredicateSubQuery is used outside of a filter") { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()), + Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -530,12 +543,13 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), + val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + val plan2 = Filter( + Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), 
LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cd8579584eada..3b3edac0a314e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import scala.reflect.ClassTag + import org.scalatest.Matchers +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -232,7 +235,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkAnalysis(plan, expected) } - test("Analysis may leave unnecassary aliases") { + test("Analysis may leave unnecessary aliases") { val att1 = testRelation.output.head var plan = testRelation.select( CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"), @@ -270,7 +273,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("self intersect should resolve duplicate expression IDs") { - val plan = testRelation.intersect(testRelation) + val plan = testRelation.intersect(testRelation, isAll = false) assertAnalysisSuccess(plan) } @@ -314,16 +317,19 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkUDF(udf1, expected1) // only primitive parameter needs special null handling - val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) - val expected2 = If(IsNull(double), nullResult, udf2) + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil, + nullableTypes = true :: false :: Nil) + val expected2 = + If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil)) checkUDF(udf2, expected2) // special null handling should apply to all primitive parameters - val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil, + nullableTypes = false :: false :: Nil) val expected3 = If( IsNull(short) || IsNull(double), nullResult, - udf3) + udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil)) checkUDF(udf3, expected3) // we can skip special null handling for primitive parameters that are not nullable @@ -331,14 +337,24 @@ class AnalysisSuite extends AnalysisTest with Matchers { val udf4 = ScalaUDF( (s: Short, d: Double) => "x", StringType, - short :: double.withNullability(false) :: Nil) + short :: double.withNullability(false) :: Nil, + nullableTypes = false :: false :: Nil) val expected4 = If( IsNull(short), nullResult, - udf4) + udf4.copy(children = KnownNotNull(short) :: double.withNullability(false) :: Nil)) // checkUDF(udf4, expected4) } + test("SPARK-24891 Fix HandleNullInputsForUDF rule") { + val a = testRelation.output(0) + val func = (x: Int, y: Int) => x + y + val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil) + val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil) + val plan = Project(Alias(udf2, "")() :: Nil, testRelation) + comparePlans(plan.analyze, plan.analyze.analyze) + } + test("SPARK-11863 mixture of aliases and real columns in order by clause - 
tpcds 19,55,71") { val a = testRelation2.output(0) val c = testRelation2.output(2) @@ -426,8 +442,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { val unionPlan = Union(firstTable, secondTable) assertAnalysisSuccess(unionPlan) - val r1 = Except(firstTable, secondTable) - val r2 = Intersect(firstTable, secondTable) + val r1 = Except(firstTable, secondTable, isAll = false) + val r2 = Intersect(firstTable, secondTable, isAll = false) assertAnalysisSuccess(r1) assertAnalysisSuccess(r2) @@ -518,9 +534,11 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-22614 RepartitionByExpression partitioning") { - def checkPartitioning[T <: Partitioning](numPartitions: Int, exprs: Expression*): Unit = { + def checkPartitioning[T <: Partitioning: ClassTag]( + numPartitions: Int, exprs: Expression*): Unit = { val partitioning = RepartitionByExpression(exprs, testRelation2, numPartitions).partitioning - assert(partitioning.isInstanceOf[T]) + val clazz = implicitly[ClassTag[T]].runtimeClass + assert(clazz.isInstance(partitioning)) } checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20)) @@ -544,17 +562,28 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } - test("SPARK-20392: analysis barrier") { - // [[AnalysisBarrier]] will be removed after analysis - checkAnalysis( - Project(Seq(UnresolvedAttribute("tbl.a")), - AnalysisBarrier(SubqueryAlias("tbl", testRelation))), - Project(testRelation.output, SubqueryAlias("tbl", testRelation))) - - // Verify we won't go through a plan wrapped in a barrier. - // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. - val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), - SubqueryAlias("tbl", testRelation))) - assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val pythonUdf = PythonUDF("pyUDF", null, + StructType(Seq(StructField("a", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true) + val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val project = Project(Seq(UnresolvedAttribute("a")), testRelation) + val flatMapGroupsInPandas = FlatMapGroupsInPandas( + Seq(UnresolvedAttribute("a")), pythonUdf, output, project) + val left = SubqueryAlias("temp0", flatMapGroupsInPandas) + val right = SubqueryAlias("temp1", flatMapGroupsInPandas) + val join = Join(left, right, Inner, None) + assertAnalysisSuccess( + Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) + } + + test("SPARK-24488 Generator with multiple aliases") { + assertAnalysisSuccess( + listRelation.select(Explode('list).as("first_alias").as("second_alias"))) + assertAnalysisSuccess( + listRelation.select(MultiAlias(MultiAlias( + PosExplode('list), Seq("first_pos", "first_val")), Seq("second_pos", "second_val")))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala new file mode 100644 index 0000000000000..6c899b610ac5b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, UpCast} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType} + +case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { + override def name: String = "table-name" +} + +class DataSourceV2AnalysisSuite extends AnalysisTest { + val table = TestRelation(StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType))).toAttributes) + + val requiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))).toAttributes) + + val widerTable = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType))).toAttributes) + + test("Append.byName: basic behavior") { + val query = TestRelation(table.schema.toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + checkAnalysis(parsedPlan, parsedPlan) + assertResolved(parsedPlan) + } + + test("Append.byName: does not match by position") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'", "'y'")) + } + + test("Append.byName: case sensitive column resolution") { + val query = TestRelation(StructType(Seq( + StructField("X", FloatType), // doesn't match case! + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'"), + caseSensitive = true) + } + + test("Append.byName: case insensitive column resolution") { + val query = TestRelation(StructType(Seq( + StructField("X", FloatType), // doesn't match case! 
+ StructField("y", FloatType))).toAttributes) + + val X = query.output.head + val y = query.output.last + + val parsedPlan = AppendData.byName(table, query) + val expectedPlan = AppendData.byName(table, + Project(Seq( + Alias(Cast(toLower(X), FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan, caseSensitive = false) + assertResolved(expectedPlan) + } + + test("Append.byName: data columns are reordered by name") { + // out of order + val query = TestRelation(StructType(Seq( + StructField("y", FloatType), + StructField("x", FloatType))).toAttributes) + + val y = query.output.head + val x = query.output.last + + val parsedPlan = AppendData.byName(table, query) + val expectedPlan = AppendData.byName(table, + Project(Seq( + Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byName: fail nullable data written to required columns") { + val parsedPlan = AppendData.byName(requiredTable, table) + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", "'y'")) + } + + test("Append.byName: allow required data written to nullable columns") { + val parsedPlan = AppendData.byName(table, requiredTable) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + + test("Append.byName: missing required columns cause failure and are identified by name") { + // missing required field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType, nullable = false))).toAttributes) + + val parsedPlan = AppendData.byName(requiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'")) + } + + test("Append.byName: missing optional columns cause failure and are identified by name") { + // missing optional field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'")) + } + + test("Append.byName: fail canWrite check") { + val parsedPlan = AppendData.byName(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("Append.byName: insert safe cast") { + val x = table.output.head + val y = table.output.last + + val parsedPlan = AppendData.byName(widerTable, table) + val expectedPlan = AppendData.byName(widerTable, + Project(Seq( + Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), + table)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byName: fail extra data fields") { + val query = TestRelation(StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType), + 
StructField("z", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "too many data columns", + "Table columns: 'x', 'y'", + "Data columns: 'x', 'y', 'z'")) + } + + test("Append.byName: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot safely cast", "'x'", "DoubleType to FloatType", + "Cannot write nullable values to non-null column", "'x'", + "Cannot find data for output column", "'y'")) + } + + test("Append.byPosition: basic behavior") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType))).toAttributes) + + val a = query.output.head + val b = query.output.last + + val parsedPlan = AppendData.byPosition(table, query) + val expectedPlan = AppendData.byPosition(table, + Project(Seq( + Alias(Cast(a, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(b, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan, caseSensitive = false) + assertResolved(expectedPlan) + } + + test("Append.byPosition: data columns are not reordered") { + // out of order + val query = TestRelation(StructType(Seq( + StructField("y", FloatType), + StructField("x", FloatType))).toAttributes) + + val y = query.output.head + val x = query.output.last + + val parsedPlan = AppendData.byPosition(table, query) + val expectedPlan = AppendData.byPosition(table, + Project(Seq( + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byPosition: fail nullable data written to required columns") { + val parsedPlan = AppendData.byPosition(requiredTable, table) + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", "'y'")) + } + + test("Append.byPosition: allow required data written to nullable columns") { + val parsedPlan = AppendData.byPosition(table, requiredTable) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + + test("Append.byPosition: missing required columns cause failure") { + // missing optional field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType, nullable = false))).toAttributes) + + val parsedPlan = AppendData.byPosition(requiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "not enough data columns", + "Table columns: 'x', 'y'", + "Data columns: 'y'")) + } + + test("Append.byPosition: missing optional columns cause failure") { + // missing optional field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byPosition(table, query) + + 
assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "not enough data columns", + "Table columns: 'x', 'y'", + "Data columns: 'y'")) + } + + test("Append.byPosition: fail canWrite check") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val parsedPlan = AppendData.byPosition(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("Append.byPosition: insert safe cast") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val x = table.output.head + val y = table.output.last + + val parsedPlan = AppendData.byPosition(widerTable, table) + val expectedPlan = AppendData.byPosition(widerTable, + Project(Seq( + Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "a")(), + Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "b")()), + table)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byPosition: fail extra data fields") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType), + StructField("c", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "too many data columns", + "Table columns: 'x', 'y'", + "Data columns: 'a', 'b', 'c'")) + } + + test("Append.byPosition: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byPosition(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", + "Cannot safely cast", "'x'", "DoubleType to FloatType")) + } + + def assertNotResolved(logicalPlan: LogicalPlan): Unit = { + assert(!logicalPlan.resolved, s"Plan should not be resolved: $logicalPlan") + } + + def assertResolved(logicalPlan: LogicalPlan): Unit = { + assert(logicalPlan.resolved, s"Plan should be resolved: $logicalPlan") + } + + def toLower(attr: AttributeReference): AttributeReference = { + AttributeReference(attr.name.toLowerCase(Locale.ROOT), attr.dataType)(attr.exprId) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index c86dc18dfa680..bd87ca6017e99 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -272,6 +272,15 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { } } + test("SPARK-24468: operations on decimals with negative scale") { + val a = AttributeReference("a", DecimalType(3, -10))() + val b = AttributeReference("b", DecimalType(1, -1))() + val c = AttributeReference("c", 
DecimalType(35, 1))() + checkType(Multiply(a, b), DecimalType(5, -11)) + checkType(Multiply(a, c), DecimalType(38, -9)) + checkType(Multiply(b, c), DecimalType(37, 0)) + } + /** strength reduction for integer/decimal comparisons */ def ruleTest(initial: Expression, transformed: Expression): Unit = { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 36714bd631b0e..8eec14842c7e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -109,17 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType") + assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type map") assertError(EqualNullSafe('mapField, 'mapField), - "EqualNullSafe does not support ordering on type MapType") + "EqualNullSafe does not support ordering on type map") assertError(LessThan('mapField, 'mapField), - "LessThan does not support ordering on type MapType") + "LessThan does not support ordering on type map") assertError(LessThanOrEqual('mapField, 'mapField), - "LessThanOrEqual does not support ordering on type MapType") + "LessThanOrEqual does not support ordering on type map") assertError(GreaterThan('mapField, 'mapField), - "GreaterThan does not support ordering on type MapType") + "GreaterThan does not support ordering on type map") assertError(GreaterThanOrEqual('mapField, 'mapField), - "GreaterThanOrEqual does not support ordering on type MapType") + "GreaterThanOrEqual does not support ordering on type map") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") @@ -169,10 +169,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), "Field name should not be null") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala new file mode 100644 index 0000000000000..cea0f2a9cbc97 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.net.URI + +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +class LookupFunctionsSuite extends PlanTest { + + test("SPARK-23486: the functionExists for the Persistent function check") { + val externalCatalog = new CustomInMemoryCatalog + val conf = new SQLConf() + val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(), + Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(), + Alias(unresolvedRegisteredFunc, "call5")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(externalCatalog.getFunctionExistsCalledTimes == 1) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedPersistentFunc.name).database == Some("default")) + } + + test("SPARK-23486: the functionExists for the Registered function check") { + val externalCatalog = new InMemoryCatalog + val conf = new SQLConf() + val customerFunctionReg = new CustomerFunctionRegistry + val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedRegisteredFunc.name).database == Some("default")) + } +} + +class CustomerFunctionRegistry extends SimpleFunctionRegistry { + + private var isRegisteredFunctionCalledTimes: Int = 0; + + override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized { + isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1 + true + } + + def getIsRegisteredFunctionCalledTimes: 
Int = isRegisteredFunctionCalledTimes +} + +class CustomInMemoryCatalog extends InMemoryCatalog { + + private var functionExistsCalledTimes: Int = 0 + + override def functionExists(db: String, funcName: String): Boolean = synchronized { + functionExistsCalledTimes = functionExistsCalledTimes + 1 + true + } + + def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 553b1598e7750..8da4d7e3aa372 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -91,6 +91,34 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list")) } + test("grouping sets with no explicit group by expressions") { + val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Nil, r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + // Computation of grouping expression should remove duplicate expression based on their + // semantics (semanticEqual). + val originalPlan2 = GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))), + Seq(Multiply(Literal(2), unresolved_a), unresolved_b)), Nil, r1, + Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), + unresolved_b, UnresolvedAlias(count(unresolved_c)))) + + val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2) + val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions + assert(gExpressions.size == 3) + val firstGroupingExprAttrName = + gExpressions(0).asInstanceOf[AttributeReference].name.replaceAll("#[0-9]*", "#0") + assert(firstGroupingExprAttrName == "(a#0 * 2)") + assert(gExpressions(1).asInstanceOf[AttributeReference].name == "b") + assert(gExpressions(2).asInstanceOf[AttributeReference].name == VirtualColumn.groupingIdName) + } + test("cube") { val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 9782b5fb0d266..bd66ee5355f45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ @@ -120,4 +121,38 @@ class ResolveHintsSuite extends AnalysisTest { testRelation.where('a > 
1).select('a).select('a).analyze, caseSensitive = false) } + + test("coalesce and repartition hint") { + checkAnalysis( + UnresolvedHint("COALESCE", Seq(Literal(10)), table("TaBlE")), + Repartition(numPartitions = 10, shuffle = false, child = testRelation)) + checkAnalysis( + UnresolvedHint("coalesce", Seq(Literal(20)), table("TaBlE")), + Repartition(numPartitions = 20, shuffle = false, child = testRelation)) + checkAnalysis( + UnresolvedHint("REPARTITION", Seq(Literal(100)), table("TaBlE")), + Repartition(numPartitions = 100, shuffle = true, child = testRelation)) + checkAnalysis( + UnresolvedHint("RePARTITion", Seq(Literal(200)), table("TaBlE")), + Repartition(numPartitions = 200, shuffle = true, child = testRelation)) + + val errMsgCoal = "COALESCE Hint expects a partition number as parameter" + assertAnalysisError( + UnresolvedHint("COALESCE", Seq.empty, table("TaBlE")), + Seq(errMsgCoal)) + assertAnalysisError( + UnresolvedHint("COALESCE", Seq(Literal(10), Literal(false)), table("TaBlE")), + Seq(errMsgCoal)) + assertAnalysisError( + UnresolvedHint("COALESCE", Seq(Literal(1.0)), table("TaBlE")), + Seq(errMsgCoal)) + + val errMsgRepa = "REPARTITION Hint expects a partition number as parameter" + assertAnalysisError( + UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")), + Seq(errMsgRepa)) + assertAnalysisError( + UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")), + Seq(errMsgRepa)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala new file mode 100644 index 0000000000000..c4171c75ecd03 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{ArrayType, IntegerType} + +/** + * Test suite for [[ResolveLambdaVariables]]. 
+ */ +class ResolveLambdaVariablesSuite extends PlanTest { + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + object Analyzer extends RuleExecutor[LogicalPlan] { + val batches = Batch("Resolution", FixedPoint(4), ResolveLambdaVariables(conf)) :: Nil + } + + private val key = 'key.int + private val values1 = 'values1.array(IntegerType) + private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType))) + private val data = LocalRelation(Seq(key, values1, values2)) + private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true) + private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true) + private val lvArray = NamedLambdaVariable("x", ArrayType(IntegerType), nullable = true) + + private def plan(e: Expression): LogicalPlan = data.select(e.as("res")) + + private def checkExpression(e1: Expression, e2: Expression): Unit = { + comparePlans(Analyzer.execute(plan(e1)), plan(e2)) + } + + test("resolution - no op") { + checkExpression(key, key) + } + + test("resolution - simple") { + val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil)) + val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil)) + checkExpression(in, out) + } + + test("resolution - nested") { + val in = ArrayTransform(values2, LambdaFunction( + ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil)) + val out = ArrayTransform(values2, LambdaFunction( + ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil)) + checkExpression(in, out) + } + + test("resolution - hidden") { + val in = ArrayTransform(values1, key) + val out = ArrayTransform(values1, LambdaFunction(key, lvHiddenInt :: Nil, hidden = true)) + checkExpression(in, out) + } + + test("fail - name collisions") { + val p = plan(ArrayTransform(values1, + LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil))) + val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage + assert(msg.contains("arguments should not have names that are semantically the same")) + } + + test("fail - lambda arguments") { + val p = plan(ArrayTransform(values1, + LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil))) + val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage + assert(msg.contains("does not match the number of arguments expected")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 1bf8d76da04d8..74a8590b5eefe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} /** @@ -33,7 +33,8 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), 
t1) + val expr = Filter( + InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index fd6a3121663ed..461eda4334bb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,8 +54,9 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable when all the internal child types are castable according to the table. // Note: ArrayType* is castable when the element type is castable according to the table. + // Note: MapType* is castable when both the key type and the value type are castable according to the table. 
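// Editor's sketch (illustrative only, not part of the diff above): the notes just added to this
// suite describe element-wise castability for complex types -- an array is castable when its
// element type is, a map when both its key and value types are, and a struct when every field is.
// Assuming the public Cast.canCast(from, to) helper in
// org.apache.spark.sql.catalyst.expressions.Cast is reachable from a catalyst test (the
// implicit-cast table this suite documents follows analogous element/key/value/field-wise rules),
// a minimal probe of that behavior could look like the following; treat it as a hedged sketch,
// not as code from this patch.

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// the element type (Int -> Long) is castable, so the array cast is allowed
assert(Cast.canCast(ArrayType(IntegerType), ArrayType(LongType)))
// both the key (Int -> Long) and the value (String -> String) are castable
assert(Cast.canCast(MapType(IntegerType, StringType), MapType(LongType, StringType)))
// a struct is castable when every field is castable, checked field by field
assert(Cast.canCast(
  new StructType().add("a", IntegerType),
  new StructType().add("a", LongType)))
// a map cannot be cast to an integer, so this combination is rejected
assert(!Cast.canCast(MapType(IntegerType, StringType), IntegerType))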
// scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { @@ -396,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest { widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), StructType(Seq(StructField("a", DoubleType, nullable = false))), - None) + Some(StructType(Seq(StructField("a", DoubleType, nullable = false))))) widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), @@ -429,21 +430,42 @@ class TypeCoercionSuite extends AnalysisTest { Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), isSymmetric = false) } + + widenTest( + ArrayType(IntegerType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + MapType(IntegerType, StringType, valueContainsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + Some(MapType(IntegerType, StringType, valueContainsNull = true))) + + widenTest( + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false), + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true), + Some(new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true))) } test("wider common type for decimal and array") { def widenTestWithStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric) } def widenTestWithoutStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType( + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric) } // Decimal @@ -469,12 +491,140 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(ArrayType(IntegerType), containsNull = false), ArrayType(ArrayType(LongType), containsNull = false), Some(ArrayType(ArrayType(LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(MapType(IntegerType, FloatType), containsNull = false), + ArrayType(MapType(LongType, DoubleType), containsNull = false), + Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(new StructType().add("num", ShortType), containsNull = false), + ArrayType(new StructType().add("num", LongType), containsNull = false), + Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(IntegerType, containsNull = false), + ArrayType(DecimalType.IntDecimal, containsNull = false), + Some(ArrayType(DecimalType.IntDecimal, containsNull = false))) + widenTestWithStringPromotion( + ArrayType(DecimalType(36, 0), containsNull = false), + ArrayType(DecimalType(36, 35), containsNull = false), + Some(ArrayType(DecimalType(38, 35), containsNull = true))) + + // MapType + widenTestWithStringPromotion( + MapType(ShortType, TimestampType, valueContainsNull = true), + MapType(DoubleType, StringType, valueContainsNull = false), + Some(MapType(DoubleType, StringType, valueContainsNull = true))) + widenTestWithStringPromotion( + 
MapType(IntegerType, ArrayType(TimestampType), valueContainsNull = false), + MapType(LongType, ArrayType(StringType), valueContainsNull = true), + Some(MapType(LongType, ArrayType(StringType), valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), + MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), + Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), + MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), + Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType, valueContainsNull = false), + MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false), + Some(MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(StringType, DecimalType(36, 0), valueContainsNull = false), + MapType(StringType, DecimalType(36, 35), valueContainsNull = false), + Some(MapType(StringType, DecimalType(38, 35), valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType, valueContainsNull = false), + MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false), + Some(MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(DecimalType(36, 0), StringType, valueContainsNull = false), + MapType(DecimalType(36, 35), StringType, valueContainsNull = false), + None) + + // StructType + widenTestWithStringPromotion( + new StructType() + .add("num", ShortType, nullable = true).add("ts", StringType, nullable = false), + new StructType() + .add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true), + Some(new StructType() + .add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true))) + widenTestWithStringPromotion( + new StructType() + .add("arr", ArrayType(ShortType, containsNull = false), nullable = false), + new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false), + Some(new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType() + .add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false), + new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), + Some(new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType().add("num", IntegerType, nullable = false), + new StructType().add("num", DecimalType.IntDecimal, nullable = false), + Some(new StructType().add("num", DecimalType.IntDecimal, nullable = false))) + widenTestWithStringPromotion( + new StructType().add("num", DecimalType(36, 0), nullable = false), + new StructType().add("num", DecimalType(36, 35), nullable = false), + Some(new StructType().add("num", DecimalType(38, 35), nullable = true))) + + widenTestWithStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", 
LongType).add("str", StringType), + None) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + } // Without string promotion widenTestWithoutStringPromotion(IntegerType, StringType, None) widenTestWithoutStringPromotion(StringType, TimestampType, None) widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + widenTestWithoutStringPromotion( + MapType(LongType, IntegerType), MapType(StringType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, LongType), MapType(IntegerType, StringType), None) + widenTestWithoutStringPromotion( + MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -483,6 +633,30 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) widenTestWithStringPromotion( ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) + widenTestWithStringPromotion( + MapType(LongType, IntegerType), + MapType(StringType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, LongType), + MapType(IntegerType, StringType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType), + MapType(TimestampType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType), + MapType(IntegerType, TimestampType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + Some(new StructType().add("a", StringType))) + widenTestWithStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + Some(new StructType().add("a", StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { @@ -506,11 +680,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - 
ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -518,11 +692,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } @@ -545,46 +719,43 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), - Coalesce(Seq(Cast(doubleLit, DoubleType), - Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + Coalesce(Seq(doubleLit, Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) ruleTest(rule, Coalesce(Seq(longLit, intLit, decimalLit)), Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), - Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + Cast(intLit, DecimalType(22, 0)), decimalLit))) ruleTest(rule, Coalesce(Seq(nullLit, intLit)), - Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + Coalesce(Seq(Cast(nullLit, IntegerType), intLit))) ruleTest(rule, Coalesce(Seq(timestampLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, intLit)), - Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), - Cast(intLit, FloatType)))) + Coalesce(Seq(Cast(nullLit, FloatType), floatNullLit, Cast(intLit, FloatType)))) ruleTest(rule, Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), - Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + Cast(decimalLit, DoubleType), doubleLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), - Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + Cast(doubleLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(timestampLit, intLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), - Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), - Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) + Cast(intArrayLit, ArrayType(StringType)), strArrayLit))) } test("CreateArray casts") { @@ -593,7 +764,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - CreateArray(Cast(Literal(1.0), DoubleType) + CreateArray(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -605,7 +776,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Cast(Literal(1.0), StringType) :: Cast(Literal(1), 
StringType) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -623,7 +794,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38)) :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38)) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) } @@ -637,7 +808,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateMap(Cast(Literal(1), FloatType) :: Literal("a") - :: Cast(Literal.create(2.0, FloatType), FloatType) + :: Literal.create(2.0, FloatType) :: Literal("b") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -659,7 +830,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Literal(1) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) @@ -672,7 +843,7 @@ class TypeCoercionSuite extends AnalysisTest { CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38)) :: Literal(2) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -682,8 +853,8 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Cast(Literal(1), DoubleType) - :: Cast(Literal("a"), StringType) - :: Cast(Literal(2.0), DoubleType) + :: Literal("a") + :: Literal(2.0) :: Cast(Literal(3.0), StringType) :: Nil)) } @@ -695,7 +866,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - operator(Cast(Literal(1.0), DoubleType) + operator(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -706,14 +877,14 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), operator(Cast(Literal(1L), DecimalType(22, 0)) :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) :: Nil), - operator(Literal(1.0).cast(DoubleType) + operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) @@ -805,7 +976,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -1084,8 +1255,10 @@ class TypeCoercionSuite extends AnalysisTest { val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except] - val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes( + Except(firstTable, secondTable, isAll = false)).asInstanceOf[Except] + val r2 = widenSetOperationTypes( + Intersect(firstTable, secondTable, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) 
checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) @@ -1150,8 +1323,10 @@ class TypeCoercionSuite extends AnalysisTest { val expectedType1 = Seq(DecimalType(10, 8)) val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] - val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except] - val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect] + val r2 = widenSetOperationTypes( + Except(left1, right1, isAll = false)).asInstanceOf[Except] + val r3 = widenSetOperationTypes( + Intersect(left1, right1, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedType1) checkOutput(r1.children.last, expectedType1) @@ -1171,16 +1346,20 @@ class TypeCoercionSuite extends AnalysisTest { AttributeReference("r", rType)()) val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] - val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except] - val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect] + val r2 = widenSetOperationTypes( + Except(plan1, plan2, isAll = false)).asInstanceOf[Except] + val r3 = widenSetOperationTypes( + Intersect(plan1, plan2, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] - val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except] - val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect] + val r5 = widenSetOperationTypes( + Except(plan2, plan1, isAll = false)).asInstanceOf[Except] + val r6 = widenSetOperationTypes( + Intersect(plan2, plan1, isAll = false)).asInstanceOf[Intersect] checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) @@ -1257,7 +1436,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 60d1351fda264..28a164b5d0cad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -575,14 +575,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Except: *-stream not supported testBinaryOperationInStreamingPlan( "except", - _.except(_), + _.except(_, isAll = false), streamStreamSupported = false, batchStreamSupported = false) // Intersect: stream-stream not supported testBinaryOperationInStreamingPlan( "intersect", - _.intersect(_), + _.intersect(_, isAll = false), streamStreamSupported = false) // Sort: supported only on batch subplans and after aggregation on streaming plan + complete mode @@ -621,6 +621,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = 
Seq("monotonically_increasing_id")) + assertSupportedForContinuousProcessing( + "TypedFilter", TypedFilter( + null, + null, + null, + null, + new TestStreamingRelationV2(attribute)), OutputMode.Append()) /* ======================================================================================= @@ -759,7 +766,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { * * To test this correctly, the given logical plan is wrapped in a fake operator that makes the * whole plan look like a streaming plan. Otherwise, a batch plan may throw not supported - * exception simply for not being a streaming plan, even though that plan could exists as batch + * exception simply for not being a streaming plan, even though that plan could exist as batch * subplan inside some streaming plan. */ def assertSupportedInStreamingPlan( @@ -771,12 +778,22 @@ class UnsupportedOperationsSuite extends SparkFunSuite { } } + /** Assert that the logical plan is supported for continuous procsssing mode */ + def assertSupportedForContinuousProcessing( + name: String, + plan: LogicalPlan, + outputMode: OutputMode): Unit = { + test(s"continuous processing - $name: supported") { + UnsupportedOperationChecker.checkForContinuous(plan, outputMode) + } + } + /** * Assert that the logical plan is not supported inside a streaming plan. * * To test this correctly, the given logical plan is wrapped in a fake operator that makes the * whole plan look like a streaming plan. Otherwise, a batch plan may throw not supported - * exception simply for not being a streaming plan, even though that plan could exists as batch + * exception simply for not being a streaming plan, even though that plan could exist as batch * subplan inside some streaming plan. */ def assertNotSupportedInStreamingPlan( @@ -840,4 +857,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { def this(attribute: Attribute) = this(Seq(attribute)) override def isStreaming: Boolean = true } + + case class TestStreamingRelationV2(output: Seq[Attribute]) extends LeafNode { + def this(attribute: Attribute) = this(Seq(attribute)) + override def isStreaming: Boolean = true + override def nodeName: String = "StreamingRelationV2" + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 1acbe34d9a075..2fcaeca34db3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -36,7 +36,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite { private def testWithCatalog( name: String)( f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { - val catalog = newCatalog + val catalog = new ExternalCatalogWithListener(newCatalog) val recorder = mutable.Buffer.empty[ExternalCatalogEvent] catalog.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 6abab0073cca3..89fabd4774065 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -537,11 +537,11 @@ abstract class SessionCatalogSuite extends AnalysisTest { val view = View(desc = metadata, output = metadata.schema.toAttributes, child = CatalystSqlParser.parsePlan(metadata.viewText.get)) comparePlans(catalog.lookupRelation(TableIdentifier("view1", Some("db3"))), - SubqueryAlias("view1", view)) + SubqueryAlias("view1", "db3", view)) // Look up a view using current database of the session catalog. catalog.setCurrentDatabase("db3") comparePlans(catalog.lookupRelation(TableIdentifier("view1")), - SubqueryAlias("view1", view)) + SubqueryAlias("view1", "db3", view)) } } @@ -1114,11 +1114,13 @@ abstract class SessionCatalogSuite extends AnalysisTest { // And for hive serde table, hive metastore will set some values(e.g.transient_lastDdlTime) // in table's parameters and storage's properties, here we also ignore them. val actualPartsNormalize = actualParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet val expectedPartsNormalize = expectedParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet actualPartsNormalize == expectedPartsNormalize @@ -1215,6 +1217,42 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } + test("isRegisteredFunction") { + withBasicCatalog { catalog => + // Returns false when the function is not registered + assert(!catalog.isRegisteredFunction(FunctionIdentifier("temp1"))) + + // Returns true when the function is registered + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc1) ) + assert(catalog.isRegisteredFunction(FunctionIdentifier("iff"))) + + // Returns false for a persistent function created with createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum"))) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum", Some("db2")))) + } + } + + test("isPersistentFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not exist + assert(!catalog.isPersistentFunction(FunctionIdentifier("temp2"))) + + // Returns false when the function is only registered as a temporary function + val tempFunc2 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc2)) + assert(!catalog.isPersistentFunction(FunctionIdentifier("iff"))) + + // Returns true for a persistent function created with createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(catalog.isPersistentFunction(FunctionIdentifier("sum", Some("db2")))) +
assert(!catalog.isPersistentFunction(FunctionIdentifier("db2.sum"))) + } + } + test("drop function") { withBasicCatalog { catalog => assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 630113ce2d948..dd20e6497fbb4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -144,7 +144,7 @@ class EncoderResolutionSuite extends PlanTest { // It should pass analysis val bound = encoder.resolveAndBind(attrs) - // If no null values appear, it should works fine + // If no null values appear, it should work fine bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) // If there is null value, it should throw runtime exception diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index e6d09bdae67d7..f0d61de97ffcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} -import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ @@ -112,7 +112,7 @@ object ReferenceValueClass { case class Container(data: Int) } -class ExpressionEncoderSuite extends PlanTest with AnalysisTest { +class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest { OuterScopes.addOuterScope(this) implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = verifyNotLeakingReflectionObjects { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 6ed175f86ca77..8d89f9c6c41d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.encoders import scala.util.Random -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -71,7 +71,7 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] { private[spark] override def asNullable: ExamplePointUDT = this } -class RowEncoderSuite extends SparkFunSuite { +class RowEncoderSuite extends CodegenInterpretedPlanTest { private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new 
StructType().add("udt", new ExamplePointUDT, false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6edb4348f8309..9a752af523ffc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -282,6 +282,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) } + + val least = Least(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(least.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(least, Seq(1, 2)) } test("function greatest") { @@ -334,10 +340,16 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } + + val greatest = Greatest(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(greatest.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(greatest, Seq(1, 3, null)) } test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") { - val N = 3000 + val N = 2000 val strings = (1 to N).map(x => "s" * x) val inputsExpr = strings.map(Literal.create(_, StringType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala new file mode 100644 index 0000000000000..28e6940f3cca3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Range + +class CanonicalizeSuite extends SparkFunSuite { + + test("SPARK-24276: IN expression with different order are semantically equal") { + val range = Range(1, 1, 1, 1) + val idAttr = range.output.head + + val in1 = In(idAttr, Seq(Literal(1), Literal(2))) + val in2 = In(idAttr, Seq(Literal(2), Literal(1))) + val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + + assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) + assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) + + assert(range.where(in1).sameResult(range.where(in2))) + assert(!range.where(in1).sameResult(range.where(in3))) + + val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(2), Literal(1))))) + val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + CreateArray(Seq(Literal(1), Literal(2))))) + val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(3), Literal(1))))) + + assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) + assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash()) + + assert(range.where(arrays1).sameResult(range.where(arrays2))) + assert(!range.where(arrays1).sameResult(range.where(arrays3))) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5b25bdf907c3a..d9f32c000a885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -399,21 +399,35 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("casting to fixed-precision decimals") { - // Overflow and rounding for casting to fixed-precision decimals: - // - Values should round with HALF_UP mode by default when you lower scale - // - Values that would overflow the target precision should turn into null - // - Because of this, casts to fixed-precision decimals should be nullable - - assert(cast(123, DecimalType.USER_DEFAULT).nullable === true) + assert(cast(123, DecimalType.USER_DEFAULT).nullable === false) assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true) assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true) - assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === false) assert(cast(123, DecimalType(2, 1)).nullable === true) assert(cast(10.03f, DecimalType(2, 1)).nullable === true) assert(cast(10.03, DecimalType(2, 1)).nullable === true) assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) + assert(cast(123, DecimalType.IntDecimal).nullable === false) + assert(cast(10.03f, DecimalType.FloatDecimal).nullable === true) + assert(cast(10.03, DecimalType.DoubleDecimal).nullable === true) + assert(cast(Decimal(10.03), DecimalType(4, 2)).nullable === false) + assert(cast(Decimal(10.03), DecimalType(5, 3)).nullable === false) + + assert(cast(Decimal(10.03), DecimalType(3, 1)).nullable === true) + assert(cast(Decimal(10.03), DecimalType(4, 1)).nullable === false) + assert(cast(Decimal(9.95), DecimalType(2, 
1)).nullable === true) + assert(cast(Decimal(9.95), DecimalType(3, 1)).nullable === false) + + assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable === true) + assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false) + assert(cast(Decimal("995"), DecimalType(2, -1)).nullable === true) + assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false) + + assert(cast(true, DecimalType.SYSTEM_DEFAULT).nullable === false) + assert(cast(true, DecimalType(1, 1)).nullable === true) + checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) @@ -451,6 +465,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType.SYSTEM_DEFAULT), Decimal(1003)) + checkEvaluation(cast(Decimal("1003"), DecimalType(4, 0)), Decimal(1003)) + checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000)) + checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000)) + checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType(3, 0)), null) + + checkEvaluation(cast(Decimal("995"), DecimalType(3, 0)), Decimal(995)) + checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000)) + checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000)) + checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null) + checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null) + checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null) @@ -460,6 +488,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null) checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null) + + checkEvaluation(cast(true, DecimalType(2, 1)), Decimal(1)) + checkEvaluation(cast(true, DecimalType(1, 1)), null) } test("cast from date") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 5b71becee2de0..c383eec3d56b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp +import org.apache.log4j.{Appender, AppenderSkeleton, Logger} +import org.apache.log4j.spi.LoggingEvent + import org.apache.spark.SparkFunSuite import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ @@ -499,4 +503,64 @@ class 
CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil assert(names2.distinct.length == 4) } + + test("SPARK-25113: should log when there exists generated methods above HugeMethodLimit") { + class MockAppender extends AppenderSkeleton { + var seenMessage = false + + override def append(loggingEvent: LoggingEvent): Unit = { + if (loggingEvent.getRenderedMessage().contains("Generated method too long")) { + seenMessage = true + } + } + + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + + val appender = new MockAppender() + withLogAppender(appender) { + val x = 42 + val expr = HugeCodeIntExpression(x) + val proj = GenerateUnsafeProjection.generate(Seq(expr)) + val actual = proj(null) + assert(actual.getInt(0) == x) + } + assert(appender.seenMessage) + } + + private def withLogAppender(appender: Appender)(f: => Unit): Unit = { + val logger = + Logger.getLogger(classOf[CodeGenerator[_, _]].getName) + logger.addAppender(appender) + try f finally { + logger.removeAppender(appender) + } + } +} + +case class HugeCodeIntExpression(value: Int) extends Expression { + override def nullable: Boolean = true + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Nil + override def eval(input: InternalRow): Any = value + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Assuming HugeMethodLimit to be 8000 + val HugeMethodLimit = CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT + // A single "int dummyN = 0;" will be at least 2 bytes of bytecode: + // 0: iconst_0 + // 1: istore_1 + // and it'll become bigger as the number of local variables increases. + // So 4000 such dummy local variable definitions are sufficient to bump the bytecode size + // of a generated method to above 8000 bytes. + val hugeCode = (0 until (HugeMethodLimit / 2)).map(i => s"int dummy$i = 0;").mkString("\n") + val code = + code"""{ + | $hugeCode + |} + |boolean ${ev.isNull} = false; + |int ${ev.value} = $value; + """.stripMargin + ev.copy(code = code) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala new file mode 100644 index 0000000000000..28edd85ab6e87 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.concurrent.ExecutionException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.IntegerType + +class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { + + object FailedCodegenProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + val invalidCode = new CodeAndComment("invalid code", Map.empty) + // We assume this compilation throws an exception + CodeGenerator.compile(invalidCode) + null + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + } + + test("UnsafeProjection with codegen factory mode") { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val obj = UnsafeProjection.createObject(input) + assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection")) + } + + val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val obj = UnsafeProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } + + test("fallback to the interpreter mode") { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) + val fallback = CodegenObjectFactoryMode.FALLBACK.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallback) { + val obj = FailedCodegenProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } + + test("codegen failures in the CODEGEN_ONLY mode") { + val errMsg = intercept[ExecutionException] { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + FailedCodegenProjection.createObject(input) + } + }.getMessage + assert(errMsg.contains("failed to compile: org.codehaus.commons.compiler.CompileException:")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 43c5dda2e4a48..7b345aabd19c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -17,12 +17,23 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +import scala.util.Random + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH +import org.apache.spark.unsafe.types.CalendarInterval class CollectionExpressionsSuite 
extends SparkFunSuite with ExpressionEvalHelper { - test("Array and Map Size") { + def testSize(sizeOfNull: Any): Unit = { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) @@ -39,8 +50,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Size(m1), 0) checkEvaluation(Size(m2), 1) - checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1) - checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) + checkEvaluation( + Size(Literal.create(null, MapType(StringType, StringType))), + expected = sizeOfNull) + checkEvaluation( + Size(Literal.create(null, ArrayType(StringType))), + expected = sizeOfNull) + } + + test("Array and Map Size - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSize(sizeOfNull = -1) + } + } + + test("Array and Map Size") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSize(sizeOfNull = null) + } } test("MapKeys/MapValues") { @@ -56,33 +83,297 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapValues(m2), null) } + test("MapEntries") { + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys/values + val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType)) + val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType)) + val mi2 = Literal.create(null, MapType(IntegerType, IntegerType)) + val mid0 = Literal.create(Map(1 -> 1.1, 2 -> 2.2), MapType(IntegerType, DoubleType)) + + checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2))) + checkEvaluation(MapEntries(mi1), Seq.empty) + checkEvaluation(MapEntries(mi2), null) + checkEvaluation(MapEntries(mid0), Seq(r(1, 1.1), r(2, 2.2))) + + // Non-primitive-type keys/values + val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType)) + val ms1 = Literal.create(Map[Int, Int](), MapType(StringType, StringType)) + val ms2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null))) + checkEvaluation(MapEntries(ms1), Seq.empty) + checkEvaluation(MapEntries(ms2), null) + } + + test("Map Concat") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + valueContainsNull = false)) + val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + valueContainsNull = false)) + val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + MapType(ArrayType(IntegerType), IntegerType)) + val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + MapType(ArrayType(IntegerType), IntegerType)) + val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, 
IntegerType), IntegerType)) + val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m13 = Literal.create(Map(1 -> 2, 3 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val m14 = Literal.create(Map(5 -> 6), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val m15 = Literal.create(Map(7 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val mNull = Literal.create(null, MapType(StringType, StringType)) + + // overlapping maps + checkEvaluation(MapConcat(Seq(m0, m1)), + ( + Array("a", "b", "c", "a"), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // maps with no overlap + checkEvaluation(MapConcat(Seq(m0, m2)), + Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + + // 3 maps + checkEvaluation(MapConcat(Seq(m0, m1, m2)), + ( + Array("a", "b", "c", "a", "d", "e"), // keys + Array("1", "2", "3", "4", "4", "5") // values + ) + ) + + // null reference values + checkEvaluation(MapConcat(Seq(m3, m4)), + ( + Array("a", "b", "a", "c"), // keys + Array("1", "2", null, "3") // values + ) + ) + + // null primitive values + checkEvaluation(MapConcat(Seq(m5, m6)), + ( + Array("a", "b", "a", "c"), // keys + Array(1, 2, null, 3) // values + ) + ) + + // keys that are primitive + checkEvaluation(MapConcat(Seq(m11, m12)), + ( + Array(1, 2, 3, 4), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // keys that are arrays, with overlap + checkEvaluation(MapConcat(Seq(m7, m8)), + ( + Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // keys that are maps, with overlap + checkEvaluation(MapConcat(Seq(m9, m10)), + ( + Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), + Map(1 -> 2, 3 -> 4)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // both keys and value are primitive and valueContainsNull = false + checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6)) + + // both keys and value are primitive and valueContainsNull = true + checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null)) + + // null map + checkEvaluation(MapConcat(Seq(m0, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull, m0)), null) + checkEvaluation(MapConcat(Seq(mNull, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull)), null) + + // single map + checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + + // no map + checkEvaluation(MapConcat(Seq.empty), Map.empty) + + // force split expressions for input in generated code + val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") + val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") + checkEvaluation(MapConcat( + Seq( + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 + )), + (expectedKeys, expectedValues)) + + // argument checking + assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, m1)).dataType.keyType == 
StringType) + assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) + assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull) + assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) + assert(MapConcat(Seq.empty).dataType.keyType == StringType) + assert(MapConcat(Seq.empty).dataType.valueType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) + assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) + assert(!MapConcat(Seq(m1, m2)).nullable) + assert(MapConcat(Seq(m1, mNull)).nullable) + + val mapConcat = MapConcat(Seq( + Literal.create(Map(Seq(1, 2) -> Seq("a", "b")), + MapType( + ArrayType(IntegerType, containsNull = false), + ArrayType(StringType, containsNull = false), + valueContainsNull = false)), + Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)))) + assert(mapConcat.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)) + checkEvaluation(mapConcat, Map( + Seq(1, 2) -> Seq("a", "b"), + Seq(3, 4, null) -> Seq("c", "d", null), + Seq(6) -> null)) + } + + test("MapFromEntries") { + def arrayType(keyType: DataType, valueType: DataType) : DataType = { + ArrayType( + StructType(Seq( + StructField("a", keyType), + StructField("b", valueType))), + true) + } + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys and values + val aiType = arrayType(IntegerType, IntegerType) + val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) + val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai2 = Literal.create(Seq.empty, aiType) + val ai3 = Literal.create(null, aiType) + val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) + val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + + checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai2), Map.empty) + checkEvaluation(MapFromEntries(ai3), null) + checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(ai6), null) + + // Non-primitive-type keys and values + val asType = arrayType(StringType, StringType) + val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) + val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as2 = Literal.create(Seq.empty, asType) + val as3 = Literal.create(null, asType) + val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) + val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + + checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as2), Map.empty) + checkEvaluation(MapFromEntries(as3), null) + checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as5), + "The first field from a struct 
(key) can't be null.") + checkEvaluation(MapFromEntries(as6), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val d1 = new Decimal().set(10) + val d2 = new Decimal().set(100) + val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) + val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(new SortArray(a4), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) + checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) - checkEvaluation(new SortArray(a4), Seq(null, null)) + checkEvaluation(new SortArray(a5), Seq(null, null)) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + + val typeAA = ArrayType(ArrayType(IntegerType)) + val aa1 = Array[java.lang.Integer](1, 2) + val aa2 = Array[java.lang.Integer](3, null, 4) + val arrayArray = Literal.create(Seq(aa2, aa1), typeAA) + + checkEvaluation(new SortArray(arrayArray), Seq(aa1, aa2)) + + val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil))) + val aas1 = Array(create_row(1)) + val aas2 = Array(create_row(2)) + val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) + + checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) + + checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a1), Seq[Integer]()) + checkEvaluation(ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a3), Seq("a", "b", null)) + checkEvaluation(ArraySort(a4), Seq(d1, d2)) + checkEvaluation(ArraySort(a5), Seq(null, null)) + checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { @@ -90,6 +381,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) + val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq( + StructField("a", IntegerType, true))))) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) @@ -104,6 +397,252 @@ class CollectionExpressionsSuite extends 
SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContains(a4, Literal.create(create_row(1), StructType(Seq( + StructField("a", IntegerType, false))))), true) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(0), StructType(Seq( + StructField("a", IntegerType, false))))), false) + + // binary + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val be = Literal.create(Array[Byte](1, 2), BinaryType) + val nullBinary = Literal.create(null, BinaryType) + + checkEvaluation(ArrayContains(b0, be), true) + checkEvaluation(ArrayContains(b1, be), false) + checkEvaluation(ArrayContains(b0, nullBinary), null) + checkEvaluation(ArrayContains(b2, be), null) + checkEvaluation(ArrayContains(b3, be), true) + + // complex data types + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayContains(aa0, aae), true) + checkEvaluation(ArrayContains(aa1, aae), false) + } + + test("ArraysOverlap") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) + val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + + val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + checkEvaluation(ArraysOverlap(a0, a1), true) + checkEvaluation(ArraysOverlap(a0, a2), null) + checkEvaluation(ArraysOverlap(a1, a2), true) + checkEvaluation(ArraysOverlap(a1, a3), false) + checkEvaluation(ArraysOverlap(a0, emptyIntArray), false) + checkEvaluation(ArraysOverlap(a2, emptyIntArray), false) + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + + checkEvaluation(ArraysOverlap(a4, a5), true) + checkEvaluation(ArraysOverlap(a4, a6), null) + checkEvaluation(ArraysOverlap(a5, a6), false) + + // null handling + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + checkEvaluation(ArraysOverlap( + emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false) + checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) + checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) + checkEvaluation(ArraysOverlap( + Literal.create(Seq(null), ArrayType(IntegerType)), + Literal.create(Seq(null), ArrayType(IntegerType))), null) + + // arrays of binaries + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 
3)), + ArrayType(BinaryType)) + + checkEvaluation(ArraysOverlap(b0, b1), true) + checkEvaluation(ArraysOverlap(b0, b2), false) + + // arrays of complex data types + val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), + ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")), + ArrayType(ArrayType(StringType))) + val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(ArraysOverlap(aa0, aa1), true) + checkEvaluation(ArraysOverlap(aa0, aa2), false) + + // null handling with complex datatypes + val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false) + checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null) + checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null) + } + + test("Slice") { + val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType)) + + checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) + checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) + checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) + checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") + checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) + checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) + checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) + + checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b")) + checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null)) + checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4)) + } + + test("ArrayJoin") { + def testArrays( + arrays: Seq[Expression], + nullReplacement: Option[Expression], + expected: Seq[String]): Unit = { + assert(arrays.length == expected.length) + arrays.zip(expected).foreach { case (arr, exp) => + checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp) + } + } + + val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)), + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)), + Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)), + 
Literal.create(Seq[String]("a"), ArrayType(StringType))) + + val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a") + val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a") + testArrays(arrays, None, withoutNullReplacement) + testArrays(arrays, Some(Literal("NULL")), withNullReplacement) + + checkEvaluation(ArrayJoin( + Literal.create(null, ArrayType(StringType)), Literal(","), None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(null, StringType), + None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal(","), + Some(Literal.create(null, StringType))), null) + } + + test("ArraysZip") { + val literals = Seq( + Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)), + Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)), + Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)), + Literal.create(Seq("a", null, "c"), ArrayType(StringType)), + Literal.create(Seq(null, false, true), ArrayType(BooleanType)), + Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)), + Literal.create(Seq(), ArrayType(NullType)), + Literal.create(Seq(null), ArrayType(NullType)), + Literal.create(Seq(192.toByte), ArrayType(ByteType)), + Literal.create( + Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))), + Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType)) + ) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1))), + List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(2))), + List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(3))), + List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(4))), + List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(5))), + List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(6))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(7))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))), + List( + Row(9001, null, -1, "a"), + Row(9002, 1L, -3, null), + Row(9003, null, 900, "c"), + Row(null, 4L, null, null), + Row(null, 11L, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))), + List( + Row(null, 1.1, null, null, 192.toByte), + Row(false, null, null, null, null), + Row(true, 1.3, null, null, null), + Row(null, null, null, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(9), literals(0))), + List( + Row(List(1, 2, 3), 9001), + Row(null, 9002), + Row(List(4, 5), 9003), + Row(List(1, null, 3), null))) + + checkEvaluation(ArraysZip(Seq(literals(7), literals(10))), + List(Row(null, Array[Byte](1.toByte, 5.toByte)))) + + val longLiteral = + Literal.create((0 to 1000).toSeq, ArrayType(IntegerType)) + + checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)), + List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++ + (3 to 1000).map { Row(null, _) }.toList) + + val manyLiterals = 
(0 to 1000).map { _ => + Literal.create(Seq(1), ArrayType(IntegerType)) + }.toSeq + + val numbers = List( + Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*), + Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*)) + checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals), + List(numbers(0), numbers(1), numbers(2), numbers(3))) + + checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null) + checkEvaluation(ArraysZip(Seq()), List()) } test("Array Min") { @@ -126,6 +665,292 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + test("Sequence of numbers") { + // test null handling + + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType)), null) + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(1L), Literal(null, LongType)), null) + + // test sequence boundaries checking + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), + EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(0)), EmptyRow, "boundaries: 2 to 1 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(1)), EmptyRow, "boundaries: 2 to 1 by 1") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1") + + // test sequence with one element (zero step or equal start and stop) + + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(0)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(2), Literal(2)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(0), Literal(-2)), Seq(1)) + + // test sequence of different integral types (ascending and descending) + + checkEvaluation(new Sequence(Literal(1L), Literal(3L), Literal(1L)), Seq(1L, 2L, 3L)) + checkEvaluation(new Sequence(Literal(-3), Literal(3), Literal(3)), Seq(-3, 0, 3)) + checkEvaluation( + new Sequence(Literal(3.toShort), Literal(-3.toShort), Literal(-3.toShort)), + Seq(3.toShort, 0.toShort, -3.toShort)) + checkEvaluation( + new Sequence(Literal(-1.toByte), Literal(-3.toByte), Literal(-1.toByte)), + Seq(-1.toByte, -2.toByte, -3.toByte)) + } + + test("Sequence of timestamps") { + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + 
Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:01")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2017-12-31 23:59:59")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-03-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + Seq( + Timestamp.valueOf("2018-03-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 day").negate())), + Seq( + Timestamp.valueOf("2018-03-03 00:00:00"), + Timestamp.valueOf("2018-02-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-31 00:00:00")), + Literal(Timestamp.valueOf("2018-04-30 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-31 00:00:00"), + Timestamp.valueOf("2018-02-28 00:00:00"), + Timestamp.valueOf("2018-03-31 00:00:00"), + Timestamp.valueOf("2018-04-30 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 second"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:01"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:04:06")), + Literal(CalendarInterval.fromString("interval 1 month 2 minutes 3 seconds"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:02:03"), + Timestamp.valueOf("2018-03-01 00:04:06"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2023-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2022-04-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + 
Literal(Timestamp.valueOf("2022-04-01 00:00:00")), + Literal(Timestamp.valueOf("2017-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5").negate())), + Seq( + Timestamp.valueOf("2022-04-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + } + + test("Sequence on DST boundaries") { + val timeZone = TimeZone.getTimeZone("Europe/Prague") + val dstOffset = timeZone.getDSTSavings + + def noDST(t: Timestamp): Timestamp = new Timestamp(t.getTime - dstOffset) + + DateTimeTestUtils.withDefaultTimeZone(timeZone) { + // Spring time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-25 01:30:00")), + Literal(Timestamp.valueOf("2018-03-25 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-03-25 01:30:00"), + Timestamp.valueOf("2018-03-25 03:00:00"), + Timestamp.valueOf("2018-03-25 03:30:00"))) + + // Autumn time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-10-28 01:30:00")), + Literal(Timestamp.valueOf("2018-10-28 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-10-28 01:30:00"), + noDST(Timestamp.valueOf("2018-10-28 02:00:00")), + noDST(Timestamp.valueOf("2018-10-28 02:30:00")), + Timestamp.valueOf("2018-10-28 02:00:00"), + Timestamp.valueOf("2018-10-28 02:30:00"), + Timestamp.valueOf("2018-10-28 03:00:00"), + Timestamp.valueOf("2018-10-28 03:30:00"))) + } + } + + test("Sequence of dates") { + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-05")), + Literal(CalendarInterval.fromString("interval 2 days"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-05"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-03-01")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-31")), + Literal(Date.valueOf("2018-04-30")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-31"), + Date.valueOf("2018-02-28"), + Date.valueOf("2018-03-31"), + Date.valueOf("2018-04-30"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2023-01-01")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2019-06-01"), + Date.valueOf("2020-11-01"), + Date.valueOf("2022-04-01"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-02")), + Literal(Date.valueOf("1970-01-01")), + Literal(CalendarInterval.fromString("interval 1 day"))), + EmptyRow, "sequence boundaries: 1 to 0 by 1") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-01")), + Literal(Date.valueOf("1970-02-01")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + EmptyRow, + s"sequence boundaries: 0 to 2678400000000 by -${28 * CalendarInterval.MICROS_PER_DAY}") + } + } + + test("Sequence with default step") { + // +/- 1 for integral type + checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3)) + 
checkEvaluation(new Sequence(Literal(3), Literal(1)), Seq(3, 2, 1)) + + // +/- 1 day for timestamps + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-03 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-03 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-03 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + // +/- 1 day for dates + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-03"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-03"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-03")), + Literal(Date.valueOf("2018-01-01"))), + Seq( + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-01"))) + } + test("Reverse") { // Primitive-type elements val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) @@ -190,6 +1015,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayPosition(aa0, aae), 1L) + checkEvaluation(ArrayPosition(aa1, aae), 0L) } test("elementAt") { @@ -227,7 +1060,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(null, MapType(StringType, StringType)) - checkEvaluation(ElementAt(m0, Literal(1.0)), null) + assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) checkEvaluation(ElementAt(m0, Literal("d")), null) @@ -238,15 +1071,27 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m0, Literal("c")), null) checkEvaluation(ElementAt(m2, Literal("a")), null) + + // test binary type as keys + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) } test("Concat") { // Primitive-type elements - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) - val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) - val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) - val ai4 = Literal.create(null, ArrayType(IntegerType)) + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = 
Literal.create(Seq.empty[Integer], ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType, containsNull = true)) + val ai4 = Literal.create(null, ArrayType(IntegerType, containsNull = false)) checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) @@ -259,14 +1104,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(ai4, ai0)), null) // Non-primitive-type elements - val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) - val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) - val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) - val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) - val as4 = Literal.create(null, ArrayType(StringType)) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType, containsNull = true)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType, containsNull = true)) + val as4 = Literal.create(null, ArrayType(StringType, containsNull = false)) - val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) - val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa2 = Literal.create(Seq(Seq("g", null), null), + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) @@ -279,5 +1128,617 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(as4, as0)), null) checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + + assert(Concat(Seq(ai0, ai1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(ai0, ai2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(as0, as1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(as0, as2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(aa0, aa1)).dataType === + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + assert(Concat(Seq(aa0, aa2)).dataType === + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) + + // force split expressions for input in generated code + checkEvaluation(Concat(Seq.fill(100)(ai0)), Seq.fill(100)(Seq(1, 2, 3)).flatten) + } + + test("Flatten") { + // Primitive-type test cases + val intArrayType = ArrayType(ArrayType(IntegerType)) + + // Main test cases (primitive type) + val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType) + val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType) + + checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(aim2), Seq(1, 2, 3)) + + // Test cases with an empty array (primitive type) + val aie1 = 
Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType) + val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType) + val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType) + val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType) + val aie5 = Literal.create(Seq(Seq.empty), intArrayType) + val aie6 = Literal.create(Seq.empty, intArrayType) + + checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie4), Seq.empty) + checkEvaluation(Flatten(aie5), Seq.empty) + checkEvaluation(Flatten(aie6), Seq.empty) + + // Test cases with null elements (primitive type) + val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType) + val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType) + val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType) + + checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null)) + checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null)) + checkEvaluation(Flatten(ain3), Seq(null, null, null, null)) + + // Test cases with a null array (primitive type) + val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType) + val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType) + val aia3 = Literal.create(Seq(null), intArrayType) + val aia4 = Literal.create(null, intArrayType) + + checkEvaluation(Flatten(aia1), null) + checkEvaluation(Flatten(aia2), null) + checkEvaluation(Flatten(aia3), null) + checkEvaluation(Flatten(aia4), null) + + // Non-primitive-type test cases + val strArrayType = ArrayType(ArrayType(StringType)) + val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) + + // Main test cases (non-primitive type) + val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) + val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType) + val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) + + checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(asm2), Seq("a", "b")) + checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) + + // Test cases with an empty array (non-primitive type) + val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType) + val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) + val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) + val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType) + val ase5 = Literal.create(Seq(Seq.empty), strArrayType) + val ase6 = Literal.create(Seq.empty, strArrayType) + + checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase4), Seq.empty) + checkEvaluation(Flatten(ase5), Seq.empty) + checkEvaluation(Flatten(ase6), Seq.empty) + + // Test cases with null elements (non-primitive type) + val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) + val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) + val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) + + checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null)) + checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) 
+ checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) + + // Test cases with a null array (non-primitive type) + val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) + val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) + val asa3 = Literal.create(Seq(null), strArrayType) + val asa4 = Literal.create(null, strArrayType) + + checkEvaluation(Flatten(asa1), null) + checkEvaluation(Flatten(asa2), null) + checkEvaluation(Flatten(asa3), null) + checkEvaluation(Flatten(asa4), null) + } + + test("ArrayRepeat") { + val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType)) + + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi")) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi")) + checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true)) + checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1)) + checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2)) + checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null)) + checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null)) + checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2))) + checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) + } + + test("Array remove") { + val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType)) + val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a4 = Literal.create(null, ArrayType(StringType)) + val a5 = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType)) + val a6 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + + checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5)) + checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2)) + checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null) + + checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c")) + checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) + + checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null)) + checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null) + + checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer]) + + checkEvaluation(ArrayRemove(a4, Literal("a")), null) + + checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null)) + checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = 
Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + + val dataToRemove1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation(ArrayRemove(b0, dataToRemove1), + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2))) + checkEvaluation(ArrayRemove(b0, nullBinary), null) + checkEvaluation(ArrayRemove(b1, dataToRemove1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayRemove(b2, dataToRemove1), Seq[Array[Byte]](null, Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1)), ArrayType(ArrayType(IntegerType))) + val dataToRemove2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayRemove(c0, dataToRemove2), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } + + test("Array Distinct") { + val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234), + ArrayType(DoubleType)) + val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), + ArrayType(FloatType)) + + checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) + checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) + checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) + checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) + checkEvaluation(new ArrayDistinct(a4), Seq(null)) + checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) + checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) + checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2), + null, Array[Byte](5, 6), null), ArrayType(BinaryType)) + + checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2))) + checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null, + Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayDistinct(c0), 
Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } + + test("Array Union") { + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false)) + val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) + val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false)) + val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) + val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true)) + val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false)) + val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false)) + val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3)) + checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5)) + checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayUnion(abl0, abl1), Seq[Boolean](true, false)) + checkEvaluation(ArrayUnion(ab0, ab1), Seq[Byte](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(as0, as1), Seq[Short](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(af0, af1), Seq[Float](1.1F, 2.2F, 3.3F, 4.4F)) + checkEvaluation(ArrayUnion(ad0, ad1), Seq[Double](1.1, 2.2, 3.3, 4.4)) + + checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L)) + checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L)) + checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L)) + + checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) + checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) + + checkEvaluation(ArrayUnion(a30, a30), Seq(null)) + checkEvaluation(ArrayUnion(a20, a31), null) + 
checkEvaluation(ArrayUnion(a31, a20), null) + + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]]( + Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val b6 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayUnion(b0, b1), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b0, b2), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) + checkEvaluation(ArrayUnion(b3, b0), + Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayUnion(aa0, aa1), + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1))) + + assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } + + test("Shuffle") { + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType, containsNull = true)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType, containsNull = true)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType, containsNull = false)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType, containsNull = false)) + val ai7 = Literal.create(null, ArrayType(IntegerType, containsNull = true)) + + checkEvaluation(Shuffle(ai0, Some(0)), Seq(4, 1, 2, 3, 5)) + checkEvaluation(Shuffle(ai1, Some(0)), Seq(3, 1, 2)) + checkEvaluation(Shuffle(ai2, Some(0)), Seq(3, null, 1, null)) + checkEvaluation(Shuffle(ai3, Some(0)), Seq(null, 2, null, 4)) + checkEvaluation(Shuffle(ai4, Some(0)), Seq(null, null, null)) + checkEvaluation(Shuffle(ai5, Some(0)), Seq(1)) + checkEvaluation(Shuffle(ai6, Some(0)), Seq.empty) + checkEvaluation(Shuffle(ai7, Some(0)), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("a", "b", "c", "d"), 
ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType, containsNull = true)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType, containsNull = true)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType, containsNull = true)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = false)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType, containsNull = false)) + val as7 = Literal.create(null, ArrayType(StringType, containsNull = true)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Shuffle(as0, Some(0)), Seq("d", "a", "b", "c")) + checkEvaluation(Shuffle(as1, Some(0)), Seq("c", "a", "b")) + checkEvaluation(Shuffle(as2, Some(0)), Seq("c", null, "a", null)) + checkEvaluation(Shuffle(as3, Some(0)), Seq(null, "b", null, "d")) + checkEvaluation(Shuffle(as4, Some(0)), Seq(null, null, null)) + checkEvaluation(Shuffle(as5, Some(0)), Seq("a")) + checkEvaluation(Shuffle(as6, Some(0)), Seq.empty) + checkEvaluation(Shuffle(as7, Some(0)), null) + checkEvaluation(Shuffle(aa, Some(0)), Seq(Seq("e"), Seq("a", "b"), Seq("c", "d"))) + + val r = new Random(1234) + val seed1 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) === + evaluateWithoutCodegen(Shuffle(ai0, seed1))) + assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) === + evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1))) + assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) === + evaluateWithUnsafeProjection(Shuffle(ai0, seed1))) + + val seed2 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) !== + evaluateWithoutCodegen(Shuffle(ai0, seed2))) + assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) !== + evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed2))) + assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !== + evaluateWithUnsafeProjection(Shuffle(ai0, seed2))) + + val shuffle = Shuffle(ai0, seed1) + assert(shuffle.fastEquals(shuffle)) + assert(!shuffle.fastEquals(Shuffle(ai0, seed1))) + assert(!shuffle.fastEquals(shuffle.freshCopy())) + assert(!shuffle.fastEquals(Shuffle(ai0, seed2))) + } + + test("Array Except") { + val a00 = Literal.create(Seq(1, 2, 4, 3), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 4, 2), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, 1), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = 
Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 4L, 3L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 4L, 2L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, 1L), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) + + val a20 = Literal.create(Seq("b", "a", "c", "d"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", "a", "c"), ArrayType(StringType, false)) + val a24 = Literal.create(Seq("c", null, "a", "f", "c"), ArrayType(StringType, true)) + val a25 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, true)) + val a26 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayExcept(a00, a01), Seq(1, 3)) + checkEvaluation(ArrayExcept(a02, a01), Seq(1)) + checkEvaluation(ArrayExcept(a02, a02), Seq.empty) + checkEvaluation(ArrayExcept(a02, a03), Seq(1)) + checkEvaluation(ArrayExcept(a04, a02), Seq(null, 5)) + checkEvaluation(ArrayExcept(a04, a05), Seq(1, 5)) + checkEvaluation(ArrayExcept(a04, a06), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayExcept(a06, a04), Seq.empty) + checkEvaluation(ArrayExcept(abl0, abl1), Seq[Boolean](true)) + checkEvaluation(ArrayExcept(ab0, ab1), Seq[Byte](1, 3)) + checkEvaluation(ArrayExcept(as0, as1), Seq[Short](1, 3)) + checkEvaluation(ArrayExcept(af0, af1), Seq[Float](1.1F, 3.3F)) + checkEvaluation(ArrayExcept(ad0, ad1), Seq[Double](1.1, 3.3)) + + checkEvaluation(ArrayExcept(a10, a11), Seq(1L, 3L)) + checkEvaluation(ArrayExcept(a12, a11), Seq(1L)) + checkEvaluation(ArrayExcept(a12, a12), Seq.empty) + checkEvaluation(ArrayExcept(a12, a13), Seq(1L)) + checkEvaluation(ArrayExcept(a14, a12), Seq(null, 5L)) + checkEvaluation(ArrayExcept(a14, a15), Seq(1L, 5L)) + checkEvaluation(ArrayExcept(a14, a16), Seq(1L, 2L, null, 4L, 5L)) + checkEvaluation(ArrayExcept(a16, a14), Seq.empty) + + checkEvaluation(ArrayExcept(a20, a21), Seq("b", "d")) + checkEvaluation(ArrayExcept(a22, a21), Seq("b")) + checkEvaluation(ArrayExcept(a22, a22), Seq.empty) + checkEvaluation(ArrayExcept(a22, a23), Seq("b")) + checkEvaluation(ArrayExcept(a24, a22), Seq(null, "f")) + checkEvaluation(ArrayExcept(a24, a25), Seq("c", "f")) + checkEvaluation(ArrayExcept(a24, a26), Seq("c", null, "a", "f")) + checkEvaluation(ArrayExcept(a26, a24), Seq.empty) + + checkEvaluation(ArrayExcept(a30, a30), Seq.empty) + checkEvaluation(ArrayExcept(a20, a31), null) + checkEvaluation(ArrayExcept(a31, a20), null) + + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](7, 8)), + ArrayType(BinaryType)) + val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + 
Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayExcept(b0, b1), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](7, 8))) + checkEvaluation(ArrayExcept(b1, b0), Seq[Array[Byte]](Array[Byte](2, 1))) + checkEvaluation(ArrayExcept(b0, b2), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](7, 8))) + checkEvaluation(ArrayExcept(b2, b0), Seq.empty) + checkEvaluation(ArrayExcept(b2, b3), Seq[Array[Byte]](Array[Byte](1, 2))) + checkEvaluation(ArrayExcept(b3, b2), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayExcept(b3, b4), Seq[Array[Byte]](Array[Byte](2, 1))) + checkEvaluation(ArrayExcept(b4, b3), Seq.empty) + checkEvaluation(ArrayExcept(b4, b5), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayExcept(b5, b4), Seq.empty) + checkEvaluation(ArrayExcept(b4, arrayWithBinaryNull), Seq[Array[Byte]](Array[Byte](3, 4))) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayExcept(aa0, aa1), Seq[Seq[Int]](Seq[Int](1, 2))) + checkEvaluation(ArrayExcept(aa1, aa0), Seq[Seq[Int]](Seq[Int](2, 1))) + + assert(ArrayExcept(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a04, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } + + test("Array Intersect") { + val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 1, 4), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, null), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1, null), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, false, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, containsNull = false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, containsNull = false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, containsNull = false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, containsNull = false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), 
ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 1L, 4L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L, 4L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, null), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L, null), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", "a", null, "f"), ArrayType(StringType, true)) + val a24 = Literal.create(Seq("b", null, "a", "g", null), ArrayType(StringType, true)) + val a25 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayIntersect(a00, a01), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a01, a00), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a02, a03), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a03, a02), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a00, a04), Seq(1, 2, 4)) + checkEvaluation(ArrayIntersect(a04, a05), Seq(2, null, 4)) + checkEvaluation(ArrayIntersect(a02, a06), Seq.empty) + checkEvaluation(ArrayIntersect(a06, a04), Seq.empty) + checkEvaluation(ArrayIntersect(abl0, abl1), Seq[Boolean](true)) + checkEvaluation(ArrayIntersect(ab0, ab1), Seq[Byte](2)) + checkEvaluation(ArrayIntersect(as0, as1), Seq[Short](2)) + checkEvaluation(ArrayIntersect(af0, af1), Seq[Float](2.2F)) + checkEvaluation(ArrayIntersect(ad0, ad1), Seq[Double](2.2D)) + + checkEvaluation(ArrayIntersect(a10, a11), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a11, a10), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a12, a13), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a13, a12), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a14, a15), Seq(2L, null, 4L)) + checkEvaluation(ArrayIntersect(a12, a16), Seq.empty) + checkEvaluation(ArrayIntersect(a16, a14), Seq.empty) + + checkEvaluation(ArrayIntersect(a20, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a20), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a22, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a22), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a23, a24), Seq("a", null)) + checkEvaluation(ArrayIntersect(a24, a23), Seq(null, "a")) + checkEvaluation(ArrayIntersect(a24, a25), Seq.empty) + checkEvaluation(ArrayIntersect(a25, a24), Seq.empty) + + checkEvaluation(ArrayIntersect(a30, a30), Seq(null)) + checkEvaluation(ArrayIntersect(a20, a31), null) + checkEvaluation(ArrayIntersect(a31, a20), null) + + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), 
ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArrayIntersect(b0, b1), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b1, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](5, 6))) + checkEvaluation(ArrayIntersect(b0, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b2, b3), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b3, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b3, b4), Seq[Array[Byte]](Array[Byte](3, 4), null)) + checkEvaluation(ArrayIntersect(b4, b3), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b4, b5), Seq.empty) + checkEvaluation(ArrayIntersect(b5, b4), Seq.empty) + checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq[Array[Byte]](null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayIntersect(aa0, aa1), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayIntersect(aa1, aa0), Seq[Seq[Int]](Seq[Int](3, 4))) + + assert(ArrayIntersect(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b4138ce366b3a..77aaf55480ec2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -144,6 +144,13 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) + + val array = CreateArray(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)))) + assert(array.dataType === + ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + checkEvaluation(array, Seq(intSeq, intSeq :+ null)) } test("CreateMap") { @@ -184,6 +191,62 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } + + val map = CreateMap(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(strSeq, ArrayType(StringType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)), + Literal.create(strSeq :+ null, 
ArrayType(StringType, containsNull = true)))) + assert(map.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = false)) + checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null))) + } + + test("MapFromArrays") { + def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. + scala.collection.immutable.ListMap(keys.zip(values): _*) + } + + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25) + val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) + val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) + + val intArray = Literal.create(intSeq, ArrayType(IntegerType, false)) + val longArray = Literal.create(longSeq, ArrayType(LongType, false)) + val strArray = Literal.create(strSeq, ArrayType(StringType, false)) + + val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true)) + val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) + val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) + + val nullArray = Literal.create(null, ArrayType(StringType, false)) + + checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) + checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) + checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) + + checkEvaluation( + MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation(MapFromArrays(nullArray, nullArray), null) + + intercept[RuntimeException] { + checkEvaluation(MapFromArrays(intWithNullArray, strArray), null) + } + intercept[RuntimeException] { + checkEvaluation( + MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + } } test("CreateStruct") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index a099119732e25..f489d330cf453 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -113,6 +113,76 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true) } + test("if/case when - null flags of non-primitive types") { + val arrayWithNulls = Literal.create(Seq("a", null, "b"), ArrayType(StringType, true)) + val arrayWithoutNulls = Literal.create(Seq("c", "d"), ArrayType(StringType, false)) + val structWithNulls = Literal.create( + create_row(null, null), + StructType(Seq(StructField("a", IntegerType, true), StructField("b", StringType, true)))) + val structWithoutNulls = Literal.create( + create_row(1, "a"), + StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, false)))) + val mapWithNulls = Literal.create(Map(1 -> null),
MapType(IntegerType, StringType, true)) + val mapWithoutNulls = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, false)) + + val arrayIf1 = If(Literal.FalseLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf2 = If(Literal.FalseLiteral, arrayWithoutNulls, arrayWithNulls) + val arrayIf3 = If(Literal.TrueLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf4 = If(Literal.TrueLiteral, arrayWithoutNulls, arrayWithNulls) + val structIf1 = If(Literal.FalseLiteral, structWithNulls, structWithoutNulls) + val structIf2 = If(Literal.FalseLiteral, structWithoutNulls, structWithNulls) + val structIf3 = If(Literal.TrueLiteral, structWithNulls, structWithoutNulls) + val structIf4 = If(Literal.TrueLiteral, structWithoutNulls, structWithNulls) + val mapIf1 = If(Literal.FalseLiteral, mapWithNulls, mapWithoutNulls) + val mapIf2 = If(Literal.FalseLiteral, mapWithoutNulls, mapWithNulls) + val mapIf3 = If(Literal.TrueLiteral, mapWithNulls, mapWithoutNulls) + val mapIf4 = If(Literal.TrueLiteral, mapWithoutNulls, mapWithNulls) + + val arrayCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithoutNulls)), arrayWithNulls) + val arrayCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithoutNulls)), arrayWithNulls) + val structCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, structWithoutNulls)), structWithNulls) + val structCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, structWithoutNulls)), structWithNulls) + val mapCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, mapWithoutNulls)), mapWithNulls) + val mapCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, mapWithoutNulls)), mapWithNulls) + + def checkResult(expectedType: DataType, expectedValue: Any, result: Expression): Unit = { + assert(expectedType == result.dataType) + checkEvaluation(result, expectedValue) + } + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf1) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf2) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf4) + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen4) + 
checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen1) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen2) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen4) + } + test("case key when") { val row = create_row(null, 1, 2, "a", "b", "c") val c1 = 'a.int.at(0) @@ -139,7 +209,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), null, row) } - test("case key whn - internal pattern matching expects a List while apply takes a Seq") { + test("case key when - internal pattern matching expects a List while apply takes a Seq") { val indexedSeq = IndexedSeq(Literal(1), Literal(42), Literal(42), Literal(1)) val caseKeyWhaen = CaseKeyWhen(Literal(12), indexedSeq) assert(caseKeyWhaen.branches == diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 080ec487cfa6a..63b24fb9eb13a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -464,34 +464,47 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MonthsBetween( Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), - timeZoneId), - 3.94959677) + Literal.TrueLiteral, + timeZoneId = timeZoneId), 3.94959677) checkEvaluation( MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), - timeZoneId), - 0.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - timeZoneId), - -2.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), - timeZoneId), - 1.0) + Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), + Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), + Literal.FalseLiteral, + timeZoneId = timeZoneId), 3.9495967741935485) + + Seq(Literal.FalseLiteral, Literal.TrueLiteral). 
foreach { roundOff => + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 0.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), -2.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 1.0) + } val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) val tnull = Literal.create(null, TimestampType) - checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null) + checkEvaluation(MonthsBetween(t, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, t, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(tnull, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(t, t, Literal.create(null, BooleanType), timeZoneId = timeZoneId), null) checkConsistencyBetweenInterpretedAndCodegen( - (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId), - TimestampType, TimestampType) + (time1: Expression, time2: Expression, roundOff: Expression) => + MonthsBetween(time1, time2, roundOff, timeZoneId = timeZoneId), + TimestampType, TimestampType, BooleanType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala new file mode 100644 index 0000000000000..2352db405b1a8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
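The MonthsBetween changes above thread through the new roundOff argument. A small SQL-level sketch of the difference, assuming a local SparkSession (object name and session setup are illustrative; the expected values are the ones asserted above):

import org.apache.spark.sql.SparkSession

object MonthsBetweenRoundOffSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("MonthsBetweenRoundOffSketch").getOrCreate()
    // roundOff = true (the default) rounds to 8 digits: 3.94959677
    // roundOff = false keeps the raw fraction:          3.9495967741935485
    spark.sql(
      """SELECT
        |  months_between('1997-02-28 10:30:00', '1996-10-30 00:00:00', true)  AS rounded,
        |  months_between('1997-02-28 10:30:00', '1996-10-30 00:00:00', false) AS raw
        |""".stripMargin).show(false)
    spark.stop()
  }
}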
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.UUID + +import org.apache.spark.SparkFunSuite + +class ExprIdSuite extends SparkFunSuite { + + private val jvmId = UUID.randomUUID() + private val otherJvmId = UUID.randomUUID() + + test("hashcode independent of jvmId") { + val exprId1 = ExprId(12, jvmId) + val exprId2 = ExprId(12, otherJvmId) + assert(exprId1 != exprId2) + assert(exprId1.hashCode() == exprId2.hashCode()) + } + + test("equality should depend on both id and jvmId") { + val exprId1 = ExprId(1, jvmId) + val exprId2 = ExprId(1, jvmId) + assert(exprId1 == exprId2) + + val exprId3 = ExprId(1, jvmId) + val exprId4 = ExprId(2, jvmId) + assert(exprId3 != exprId4) + + val exprId5 = ExprId(1, jvmId) + val exprId6 = ExprId(1, otherJvmId) + assert(exprId5 != exprId6) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4bf6d7107d7e..6684e5ce18d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. 
*/ -trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -78,6 +79,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { java.util.Arrays.equals(result, expected) case (result: Double, expected: Spread[Double @unchecked]) => expected.asInstanceOf[Spread[Double]].isWithin(result) + case (result: InternalRow, expected: InternalRow) => + val st = dataType.asInstanceOf[StructType] + assert(result.numFields == st.length && expected.numFields == st.length) + st.zipWithIndex.forall { case (f, i) => + checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType) + } case (result: ArrayData, expected: ArrayData) => result.numElements == expected.numElements && { val et = dataType.asInstanceOf[ArrayType].elementType @@ -104,6 +111,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + expectedErrMsg: String): Unit = { + checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg) + } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( expression: => Expression, inputRow: InternalRow, @@ -196,39 +209,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) - } - - protected def checkEvaluationWithUnsafeProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow, - factory: UnsafeProjectionCreator): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - - if (expected == null) { - if (!unsafeRow.isNullAt(0)) { - val expectedRow = InternalRow(expected, expected) - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") - } - } else { - val lit = InternalRow(expected, expected) - val expectedRow = - factory.create(Array(expression.dataType, expression.dataType)).apply(lit) - if (unsafeRow != expectedRow) { - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected, expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected, expected) + val expectedRow = + UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } } } } protected def evaluateWithUnsafeProjection( expression: Expression, - 
inputRow: InternalRow = EmptyRow, - factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { + inputRow: InternalRow = EmptyRow): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 64b65e2070ed6..7c7c4cccee253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression { override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = - s""" + code""" |int some_variable = 11; |int ${ev.value} = 10; """.stripMargin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index 12eddf557109f..3ccaa5976cc28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -41,7 +41,7 @@ class ExpressionSetSuite extends SparkFunSuite { // maxHash's hashcode is calculated based on this exprId's hashcode, so we set this // exprId's hashCode to this specific value to make sure maxHash's hashcode is // `Int.MaxValue` - override def hashCode: Int = -1030353449 + override def hashCode: Int = 1394598635 // We are implementing this equals() only because the style-checking rule "you should // implement equals and hashCode together" requires us to override def equals(obj: Any): Boolean = super.equals(obj) @@ -57,7 +57,7 @@ class ExpressionSetSuite extends SparkFunSuite { // minHash's hashcode is calculated based on this exprId's hashcode, so we set this // exprId's hashCode to this specific value to make sure minHash's hashcode is // `Int.MinValue` - override def hashCode: Int = 1407330692 + override def hashCode: Int = -462684520 // We are implementing this equals() only because the style-checking rule "you should // implement equals and hashCode together" requires us to override def equals(obj: Any): Boolean = super.equals(obj) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala new file mode 100644 index 0000000000000..e13f4d98295be --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -0,0 +1,614 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types._ + +class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private def createLambda( + dt: DataType, + nullable: Boolean, + f: Expression => Expression): Expression = { + val lv = NamedLambdaVariable("arg", dt, nullable) + val function = f(lv) + LambdaFunction(function, Seq(lv)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + f: (Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val function = f(lv1, lv2) + LambdaFunction(function, Seq(lv1, lv2)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + + private def validateBinding( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction => + assert(f.arguments.size === argInfo.size) + f.arguments.zip(argInfo).foreach { + case (arg, (dataType, nullable)) => + assert(arg.dataType === dataType) + assert(arg.nullable === nullable) + } + f + } + + def transform(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) + } + + def filter(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression, + finish: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + val zeroType = zero.dataType + ArrayAggregate( + expr, + 
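The createLambda/transform/filter/aggregate helpers above bind lambda variables for the new higher-order array expressions; from SQL the same expressions are reachable with lambda syntax. A minimal sketch, assuming a local SparkSession (object name and session setup are illustrative; the commented results mirror the assertions in this suite):

import org.apache.spark.sql.SparkSession

object ArrayHigherOrderFunctionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ArrayHigherOrderFunctionsSketch").getOrCreate()
    spark.sql(
      """SELECT
        |  transform(array(1, 2, 3), x -> x + 1)                              AS plus_one,      -- [2, 3, 4]
        |  transform(array(1, 2, 3), (x, i) -> x + i)                         AS plus_index,    -- [1, 3, 5]
        |  filter(array(1, 2, 3), x -> x % 2 = 0)                             AS even_only,     -- [2]
        |  exists(array(1, 2, 3), x -> x % 2 = 0)                             AS any_even,      -- true
        |  aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10) AS sum_times_ten  -- 60
        |""".stripMargin).show(false)
    spark.stop()
  }
}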
zero, + createLambda(zeroType, true, et, cn, merge), + createLambda(zeroType, true, finish)) + .bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression): Expression = { + aggregate(expr, zero, merge, identity) + } + + def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + test("ArrayTransform") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val plusOne: Expression => Expression = x => x + 1 + val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + + checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) + checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) + checkEvaluation(transform(transform(ai0, plusIndex), plusOne), Seq(2, 4, 6)) + checkEvaluation(transform(ai1, plusOne), Seq(2, null, 4)) + checkEvaluation(transform(ai1, plusIndex), Seq(1, null, 5)) + checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6)) + checkEvaluation(transform(ain, plusOne), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val repeatTwice: Expression => Expression = x => Concat(Seq(x, x)) + val repeatIndexTimes: (Expression, Expression) => Expression = (x, i) => StringRepeat(x, i) + + checkEvaluation(transform(as0, repeatTwice), Seq("aa", "bb", "cc")) + checkEvaluation(transform(as0, repeatIndexTimes), Seq("", "b", "cc")) + checkEvaluation(transform(transform(as0, repeatIndexTimes), repeatTwice), + Seq("", "bb", "cccc")) + checkEvaluation(transform(as1, repeatTwice), Seq("aa", null, "cc")) + checkEvaluation(transform(as1, repeatIndexTimes), Seq("", null, "cc")) + checkEvaluation(transform(transform(as1, repeatIndexTimes), repeatTwice), + Seq("", null, "cccc")) + checkEvaluation(transform(asn, repeatTwice), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, array => Cast(transform(array, plusOne), StringType)), + Seq("[2, 3, 4]", null, "[5, 6]")) + checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)), + Seq("[1, 3, 5]", null, "[4, 6]")) + } + + test("MapFilter") { + def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v + + checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1)) + checkEvaluation(mapFilter(mii1, kGreaterThanV), 
Map()) + checkEvaluation(mapFilter(miin, kGreaterThanV), null) + + val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull + + checkEvaluation(mapFilter(mii0, valueIsNull), Map()) + checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null)) + checkEvaluation(mapFilter(miin, valueIsNull), null) + + val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0), + MapType(StringType, IntegerType, valueContainsNull = false)) + val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null), + MapType(StringType, IntegerType, valueContainsNull = true)) + val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false)) + + val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v + + checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0)) + checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5)) + checkEvaluation(mapFilter(msin, isLengthOfKey), null) + + val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)), + MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) + val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)), + MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true)) + val mian = Literal.create( + null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) + + val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3 + + checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2))) + checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2))) + checkEvaluation(mapFilter(mian, customFunc), null) + } + + test("ArrayFilter") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val isEven: Expression => Expression = x => x % 2 === 0 + val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 + + checkEvaluation(filter(ai0, isEven), Seq(2)) + checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) + checkEvaluation(filter(ai1, isEven), Seq.empty) + checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3)) + checkEvaluation(filter(ain, isEven), null) + checkEvaluation(filter(ain, isNullOrOdd), null) + + val as0 = + Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val startsWithA: Expression => Expression = x => x.startsWith("a") + + checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2")) + checkEvaluation(filter(as1, startsWithA), Seq("a")) + checkEvaluation(filter(asn, startsWithA), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)), + Seq(Seq(1, 3), null, Seq(5))) + } + + test("ArrayExists") { + def exists(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, 
null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val isEven: Expression => Expression = x => x % 2 === 0 + val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 + + checkEvaluation(exists(ai0, isEven), true) + checkEvaluation(exists(ai0, isNullOrOdd), true) + checkEvaluation(exists(ai1, isEven), false) + checkEvaluation(exists(ai1, isNullOrOdd), true) + checkEvaluation(exists(ain, isEven), null) + checkEvaluation(exists(ain, isNullOrOdd), null) + + val as0 = + Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val startsWithA: Expression => Expression = x => x.startsWith("a") + + checkEvaluation(exists(as0, startsWithA), true) + checkEvaluation(exists(as1, startsWithA), false) + checkEvaluation(exists(asn, startsWithA), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, ix => exists(ix, isNullOrOdd)), + Seq(true, null, true)) + } + + test("ArrayAggregate") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai2 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + checkEvaluation(aggregate(ai0, 0, (acc, elem) => acc + elem, acc => acc * 10), 60) + checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40) + checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0) + checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val as2 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + checkEvaluation(aggregate(as0, "", (acc, elem) => Concat(Seq(acc, elem))), "abc") + checkEvaluation(aggregate(as1, "", (acc, elem) => Concat(Seq(acc, coalesce(elem, "x")))), "axc") + checkEvaluation(aggregate(as2, "", (acc, elem) => Concat(Seq(acc, elem))), "") + checkEvaluation(aggregate(asn, "", (acc, elem) => Concat(Seq(acc, elem))), null) + + val aai = Literal.create(Seq[Seq[Integer]](Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation( + aggregate(aai, 0, + (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), + 15) + } + + test("TransformKeys") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val ai1 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val plusOne: (Expression, 
Expression) => Expression = (k, v) => k + 1 + val plusValue: (Expression, Expression) => Expression = (k, v) => k + v + val modKey: (Expression, Expression) => Expression = (k, v) => k % 3 + + checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) + checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) + checkEvaluation( + transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) + checkEvaluation(transformKeys(ai0, modKey), + ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + checkEvaluation( + transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) + checkEvaluation( + transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) + checkEvaluation(transformKeys(ai3, plusOne), null) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(null, + MapType(StringType, StringType, valueContainsNull = false)) + val as3 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val convertKeyToKeyLength: (Expression, Expression) => Expression = + (k, v) => Length(k) + 1 + + checkEvaluation( + transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) + checkEvaluation( + transformKeys(transformKeys(as0, concatValue), concatValue), + Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) + checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String]) + checkEvaluation( + transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength), + Map.empty[Int, String]) + checkEvaluation(transformKeys(as0, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + checkEvaluation(transformKeys(as1, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> null)) + checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) + checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String]) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) + } + + test("TransformValues") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val ai1 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1 + val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k + + checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4)) + checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + 
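The TransformKeys and TransformValues cases above back the transform_keys and transform_values SQL functions added in this patch series. A sketch, assuming a local SparkSession and a build where these functions are registered (object name and session setup are illustrative):

import org.apache.spark.sql.SparkSession

object TransformKeysValuesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("TransformKeysValuesSketch").getOrCreate()
    spark.sql(
      """SELECT
        |  transform_keys(map(1, 1, 2, 2, 3, 3), (k, v) -> k + 1)   AS keys_plus_one,   -- {2 -> 1, 3 -> 2, 4 -> 3}
        |  transform_values(map(1, 1, 2, 2, 3, 3), (k, v) -> v + 1) AS values_plus_one  -- {1 -> 2, 2 -> 3, 3 -> 4}
        |""".stripMargin).show(false)
    spark.stop()
  }
}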
checkEvaluation( + transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4)) + checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformValues(ai3, plusOne), null) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val valueTypeUpdate: (Expression, Expression) => Expression = + (k, v) => Length(v) + 1 + + checkEvaluation( + transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx")) + checkEvaluation(transformValues(as0, valueTypeUpdate), + Map("a" -> 3, "bb" -> 3, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as0, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as1, concatValue), + Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx")) + checkEvaluation(transformValues(as1, valueTypeUpdate), + Map("a" -> 3, "bb" -> null, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as1, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String]) + checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int]) + checkEvaluation( + transformValues(transformValues(as2, concatValue), valueTypeUpdate), + Map.empty[String, Int]) + checkEvaluation(transformValues(as3, concatValue), null) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + } + + test("MapZipWith") { + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, _) = left.dataType + val MapType(_, vt2, _) = right.dataType + MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) + .bind(validateBinding) + } + + val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii4 = MapFromArrays( + Literal.create(Seq(2, 2), ArrayType(IntegerType, false)), + Literal.create(Seq(20, 200), ArrayType(IntegerType, false))) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val multiplyKeyWithValues: (Expression, 
Expression, Expression) => Expression = { + (k, v1, v2) => k * v1 * v2 + } + + checkEvaluation( + map_zip_with(mii0, mii1, multiplyKeyWithValues), + Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null)) + checkEvaluation( + map_zip_with(mii0, mii2, multiplyKeyWithValues), + Map(1 -> null, 2 -> -80, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii3, multiplyKeyWithValues), + Map(1 -> null, 2 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii4, multiplyKeyWithValues), + Map(1 -> null, 2 -> 800, 3 -> null)) + checkEvaluation( + map_zip_with(mii4, mii0, multiplyKeyWithValues), + Map(2 -> 800, 1 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, miin, multiplyKeyWithValues), + null) + assert(map_zip_with(mii0, mii1, multiplyKeyWithValues).dataType === + MapType(IntegerType, IntegerType, valueContainsNull = true)) + + val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss1 = Literal.create(Map("d" -> "b", "b" -> "d"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false)) + val mss4 = MapFromArrays( + Literal.create(Seq("a", "a"), ArrayType(StringType, false)), + Literal.create(Seq("a", "n"), ArrayType(StringType, false))) + val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) + + val concat: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => Concat(Seq(k, v1, v2)) + } + + checkEvaluation( + map_zip_with(mss0, mss1, concat), + Map("a" -> null, "b" -> "byd", "d" -> "dzb")) + checkEvaluation( + map_zip_with(mss1, mss2, concat), + Map("d" -> null, "b" -> "bdt", "c" -> null, "a" -> null)) + checkEvaluation( + map_zip_with(mss0, mss3, concat), + Map("a" -> null, "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mss4, concat), + Map("a" -> "axa", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss4, mss0, concat), + Map("a" -> "aax", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mssn, concat), + null) + assert(map_zip_with(mss0, mss1, concat).dataType === + MapType(StringType, StringType, valueContainsNull = true)) + + def b(data: Byte*): Array[Byte] = Array[Byte](data: _*) + + val mbb0 = Literal.create(Map(b(1, 2) -> b(4), b(2, 1) -> b(5), b(1, 3) -> b(8)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb1 = Literal.create(Map(b(2, 1) -> b(7), b(1, 2) -> b(3), b(1, 1) -> b(6)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null), + MapType(BinaryType, BinaryType, valueContainsNull = true)) + val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb4 = MapFromArrays( + Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)), + Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false))) + val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false)) + + checkEvaluation( + map_zip_with(mbb0, mbb1, concat), + Map(b(1, 2) -> b(1, 2, 4, 3), b(2, 1) -> b(2, 1, 5, 7), b(1, 3) -> null, b(1, 1) -> null)) + checkEvaluation( + map_zip_with(mbb1, mbb2, concat), + Map(b(2, 1) -> null, b(1, 2) -> b(1, 2, 3, 2), b(1, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + 
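The MapZipWith and ZipWith cases above correspond to the map_zip_with and zip_with SQL functions. A minimal sketch, assuming a local SparkSession (object name and session setup are illustrative; the commented results follow the assertions in this suite):

import org.apache.spark.sql.SparkSession

object ZipWithSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ZipWithSketch").getOrCreate()
    spark.sql(
      """SELECT
        |  -- a key present in only one map sees a null value for the other side
        |  map_zip_with(map(1, 10, 2, 20), map(1, -1, 4, -4), (k, v1, v2) -> k * v1 * v2) AS mzw,
        |  -- the shorter array is padded with nulls before zipping
        |  zip_with(array(1, 2, 3), array(1, 2, 3, 4), (x, y) -> x + y)                   AS zw
        |""".stripMargin).show(false)
    // mzw: {1 -> -10, 2 -> null, 4 -> null}; zw: [2, 4, 6, null]
    spark.stop()
  }
}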
map_zip_with(mbb0, mbb3, concat), + Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb4, concat), + Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb4, mbb0, concat), + Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbbn, concat), + null) + } + + test("ZipWith") { + def zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression) => Expression): Expression = { + val ArrayType(leftT, _) = left.dataType + val ArrayType(rightT, _) = right.dataType + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) + } + + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq[Integer](1, null), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val add: (Expression, Expression) => Expression = (x, y) => x + y + val plusOne: Expression => Expression = x => x + 1 + + checkEvaluation(zip_with(ai0, ai1, add), Seq(2, 4, 6, null)) + checkEvaluation(zip_with(ai3, ai2, add), Seq(2, null, null)) + checkEvaluation(zip_with(ai2, ai3, add), Seq(2, null, null)) + checkEvaluation(zip_with(ain, ain, add), null) + checkEvaluation(zip_with(ai1, ain, add), null) + checkEvaluation(zip_with(ain, ai1, add), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val as2 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val concat: (Expression, Expression) => Expression = (x, y) => Concat(Seq(x, y)) + + checkEvaluation(zip_with(as0, as1, concat), Seq("aa", null, "cc")) + checkEvaluation(zip_with(as0, as2, concat), Seq("aa", null, null)) + + val aai1 = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + val aai2 = Literal.create(Seq(Seq(1, 2, 3)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation( + zip_with(aai1, aai2, (a1, a2) => + Cast(zip_with(transform(a1, plusOne), transform(a2, plusOne), add), StringType)), + Seq("[4, 6, 8]", null, null)) + checkEvaluation(zip_with(aai1, aai1, (a1, a2) => Cast(transform(a1, plusOne), StringType)), + Seq("[2, 3, 4]", null, "[5, 6]")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 7812319756eae..04f1c8ce0b83d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -688,24 +688,28 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { val input = """{ - | "a": 1, - | "c": "foo" - |} - |""".stripMargin + | "a": 1, + | "c": "foo" + |} + 
|""".stripMargin val jsonSchema = new StructType() .add("a", LongType, nullable = false) .add("b", StringType, nullable = false) .add("c", StringType, nullable = false) val output = InternalRow(1L, null, UTF8String.fromString("foo")) - checkEvaluation( - JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), - output - ) - val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) - .dataType + val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema assert(schemaToCompare == schema) } } } + + test("SPARK-24709: infer schema of json strings") { + checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct") + checkEvaluation( + SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), + "struct,col1:struct>") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index a9e0eb0e377a6..86f80fe66d28b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -219,4 +219,11 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) } + + test("SPARK-24571: char literals") { + checkEvaluation(Literal('X'), "X") + checkEvaluation(Literal.create('0'), "0") + checkEvaluation(Literal('\u0000'), "\u0000") + checkEvaluation(Literal.create('\n'), "\n") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 424c3a4696077..6e07f7a59b730 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -86,6 +86,13 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) } + + val coalesce = Coalesce(Seq( + Literal.create(null, ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 2, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(coalesce.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(coalesce, Seq(1, 2, 3)) } test("SPARK-16602 Nvl should support numeric-string cases") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bcd035c1eba0b..b0af9e07d1d1d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.catalyst.expressions import 
java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ +import scala.language.existentials import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -37,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -80,10 +82,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) checkEvaluationWithUnsafeProjection( - structEncoder.serializer.head, - structExpected, - structInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + structEncoder.serializer.head, structExpected, structInputRow) // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] @@ -91,10 +90,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) checkEvaluationWithUnsafeProjection( - arrayEncoder.serializer.head, - arrayExpected, - arrayInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -108,10 +104,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) checkEvaluationWithUnsafeProjection( - mapEncoder.serializer.head, - mapExpected, - mapInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + mapEncoder.serializer.head, mapExpected, mapInputRow) } test("SPARK-23582: StaticInvoke should support interpreted execution") { @@ -222,7 +215,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.fromObject(new java.util.LinkedList[Int]), Map("nonexisting" -> Literal(1))) checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), """A method named "nonexisting" is not declared in any enclosing class """ + "nor any supertype") @@ -286,8 +278,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluationWithUnsafeProjection( expr, expected, - inputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + inputRow) } checkEvaluationWithOptimization(expr, expected, inputRow) } @@ -296,7 +287,7 @@ class ObjectExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => - checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) } // If an input row or a field are null, a runtime exception will be thrown @@ -472,6 +463,140 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val deserializer = toMapExpr.copy(inputData = Literal.create(data)) checkObjectExprEvaluation(deserializer, expected = data) } + + test("SPARK-23595 ValidateExternalType should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + Seq( + (true, BooleanType), + (2.toByte, ByteType), + (5.toShort, ShortType), + (23, IntegerType), + (61L, LongType), + (1.0f, FloatType), + (10.0, DoubleType), + ("abcd".getBytes, BinaryType), + ("abcd", StringType), + (BigDecimal.valueOf(10), DecimalType.IntDecimal), + (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), + (Array(3, 2, 1), ArrayType(IntegerType)) + ).foreach { case (input, dt) => + val validateType = ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) + } + + checkExceptionInExpression[RuntimeException]( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + InternalRow.fromSeq(Seq(Row(1))), + "java.lang.Integer is not a valid external type for schema of double") + } + + private def javaMapSerializerFor( + keyClazz: Class[_], + valueClazz: Class[_])(inputObject: Expression): Expression = { + + def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match { + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + } + + ExternalMapToCatalyst( + inputObject, + ObjectType(keyClazz), + kvSerializerFor(_, keyClazz), + keyNullable = true, + ObjectType(valueClazz), + kvSerializerFor(_, valueClazz), + valueNullable = true + ) + } + + private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = { + import org.apache.spark.sql.catalyst.ScalaReflection._ + + val curId = new java.util.concurrent.atomic.AtomicInteger() + + def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression = + localTypeOf[V].dealias match { + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + case _ => + inputObject + } + + ExternalMapToCatalyst( + inputObject, + dataTypeFor[T], + kvSerializerFor[T], + keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive, + dataTypeFor[U], + kvSerializerFor[U], + valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive + ) + } + + test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") { + // Simple test + val scalaMap = scala.collection.Map[Int, 
String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(0, "v0") + put(1, "v1") + put(2, null) + put(3, "v3") + } + } + val expected = CatalystTypeConverters.convertToCatalyst(scalaMap) + + // Java Map + val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMap)) + checkEvaluation(serializer1, expected) + + // Scala Map + val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap)) + checkEvaluation(serializer2, expected) + + // NULL key test + val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( + null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(null, "v0") + put(1, "v1") + } + } + + // Java Map + val serializer3 = + javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMapHasNullKey)) + checkExceptionInExpression[RuntimeException]( + serializer3, EmptyRow, "Cannot use null as map key!") + + // Scala Map + val serializer4 = scalaMapSerializerFor[java.lang.Integer, String]( + Literal.fromObject(scalaMapHasNullKey)) + + checkExceptionInExpression[RuntimeException]( + serializer4, EmptyRow, "Cannot use null as map key!") + } } class TestBean extends Serializable { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1bfd180ae4393..ac76b17ef4761 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -449,4 +449,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) } + + test("Interpreted Predicate should initialize nondeterministic expressions") { + val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + interpreted.initialize(0) + assert(interpreted.eval(new UnsafeRow())) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala new file mode 100644 index 0000000000000..cc2e2a993d629 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ + +class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("SortPrefix") { + val b1 = Literal.create(false, BooleanType) + val b2 = Literal.create(true, BooleanType) + val i1 = Literal.create(20132983, IntegerType) + val i2 = Literal.create(-20132983, IntegerType) + val l1 = Literal.create(20132983, LongType) + val l2 = Literal.create(-20132983, LongType) + val millis = 1524954911000L; + // Explicitly choose a time zone, since Date objects can create different values depending on + // local time zone of the machine on which the test is running + val oldDefaultTZ = TimeZone.getDefault + val d1 = try { + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + Literal.create(new java.sql.Date(millis), DateType) + } finally { + TimeZone.setDefault(oldDefaultTZ) + } + val t1 = Literal.create(new Timestamp(millis), TimestampType) + val f1 = Literal.create(0.7788229f, FloatType) + val f2 = Literal.create(-0.7788229f, FloatType) + val db1 = Literal.create(0.7788229d, DoubleType) + val db2 = Literal.create(-0.7788229d, DoubleType) + val s1 = Literal.create("T", StringType) + val s2 = Literal.create("This is longer than 8 characters", StringType) + val bin1 = Literal.create(Array[Byte](12), BinaryType) + val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]), + BinaryType) + val dec1 = Literal(Decimal(20132983L, 10, 2)) + val dec2 = Literal(Decimal(20132983L, 19, 2)) + val dec3 = Literal(Decimal(20132983L, 21, 2)) + val list1 = Literal(List(1, 2), ArrayType(IntegerType)) + val nullVal = Literal.create(null, IntegerType) + + checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L) + checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L) + checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L) + // For some reason, the Literal.create code gives us the number of days since the epoch + checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L) + checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000) + checkEvaluation(SortPrefix(SortOrder(f1, Ascending)), + DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(f2, Ascending)), + DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(db1, Ascending)), + DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(db2, Ascending)), + DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(s1, Ascending)), + StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(s2, Ascending)), + StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)), + 
BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)), + BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L) + checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)), + DoublePrefixComparator.computePrefix(201329.83d)) + checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index f1a6f9b8889fa..aa334e040d5fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -706,6 +706,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "15,159,339,180,002,773.2778") checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(12831273.23481d), + Literal("###,###,###,###,###.###")), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")), + "123,123,324,123") + checkEvaluation( + FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)), + Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null) + assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")), + "12,332") + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, StringType)), null) + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("find in set") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index c48730bd9d1cc..1fa185cc77ebb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -30,7 +30,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite { } val b1 = a.withName("name2").withExprId(id) val b2 = a.withExprId(id) - val b3 = 
a.withQualifier(Some("qualifierName")) + val b3 = a.withQualifier(Seq("qualifierName")) assert(b1 != b2) assert(a != b1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index 351d4d0c2eac9..d46135c02bc01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -77,6 +77,19 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva } } + test("SPARK-21590: Start time works with negative values and return microseconds") { + val validDuration = "10 minutes" + for ((text, seconds) <- Seq( + ("-10 seconds", -10000000), // -1e7 + ("-1 minute", -60000000), + ("-1 hour", -3600000000L))) { // -6e7 + assert(TimeWindow(Literal(10L), validDuration, validDuration, "interval " + text).startTime + === seconds) + assert(TimeWindow(Literal(10L), validDuration, validDuration, text).startTime + === seconds) + } + } + private val parseExpression = PrivateMethod[Long]('parseExpression) test("parse sql expression for duration in microseconds - string") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c07da122cd7b8..5a646d9a850ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,25 +24,30 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - private def testWithFactory( - name: String)( - f: UnsafeProjectionCreator => Unit): Unit = { - test(name) { - f(UnsafeProjection) - f(InterpretedUnsafeProjection) + private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } } } - testWithFactory("basic conversion with only primitive types") { factory => + testBothCodegenAndInterpreted("basic conversion with only primitive types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) @@ -79,7 +84,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - testWithFactory("basic conversion with primitive, string and binary types") { factory => + testBothCodegenAndInterpreted("basic conversion 
with primitive, string and binary types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = factory.create(fieldTypes) @@ -98,7 +104,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => + testBothCodegenAndInterpreted( + "basic conversion with primitive, string, date and timestamp types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = factory.create(fieldTypes) @@ -127,7 +135,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - testWithFactory("null handling") { factory => + testBothCodegenAndInterpreted("null handling") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -248,7 +257,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - testWithFactory("NaN canonicalization") { factory => + testBothCodegenAndInterpreted("NaN canonicalization") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -263,7 +273,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - testWithFactory("basic conversion with struct type") { factory => + testBothCodegenAndInterpreted("basic conversion with struct type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) @@ -325,7 +336,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - testWithFactory("basic conversion with array type") { factory => + testBothCodegenAndInterpreted("basic conversion with array type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) @@ -355,7 +367,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - testWithFactory("basic conversion with map type") { factory => + testBothCodegenAndInterpreted("basic conversion with map type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) @@ -401,7 +414,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - testWithFactory("basic conversion with struct and array") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and array") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) @@ -440,7 +454,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion 
with struct and map") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) @@ -486,7 +501,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with array and map") { factory => + testBothCodegenAndInterpreted("basic conversion with array and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala index 85682cf6ea670..d2862c8f41d1b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers} import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.deploy.SparkSubmitSuite @@ -39,7 +39,7 @@ class BufferHolderSparkSubmitSuite val argsForSparkSubmit = Seq( "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"), "--name", "SPARK-22222", - "--master", "local-cluster[2,1,1024]", + "--master", "local-cluster[1,1,4096]", "--driver-memory", "4g", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", @@ -49,28 +49,36 @@ class BufferHolderSparkSubmitSuite } } -object BufferHolderSparkSubmitSuite { +object BufferHolderSparkSubmitSuite extends Assertions { def main(args: Array[String]): Unit = { val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - val holder = new BufferHolder(new UnsafeRow(1000)) + val unsafeRow = new UnsafeRow(1000) + val holder = new BufferHolder(unsafeRow) holder.reset() - holder.grow(roundToWord(ARRAY_MAX / 2)) - holder.reset() - holder.grow(roundToWord(ARRAY_MAX / 2 + 8)) + assert(intercept[IllegalArgumentException] { + holder.grow(-1) + }.getMessage.contains("because the size is negative")) - holder.reset() - holder.grow(roundToWord(Integer.MAX_VALUE / 2)) + // while to reuse a buffer may happen, this test checks whether the buffer can be grown + holder.grow(ARRAY_MAX / 2) + assert(unsafeRow.getSizeInBytes % 8 == 0) - holder.reset() - holder.grow(roundToWord(Integer.MAX_VALUE)) - } + holder.grow(ARRAY_MAX / 2 + 7) + assert(unsafeRow.getSizeInBytes % 8 == 0) + + holder.grow(Integer.MAX_VALUE / 2) + assert(unsafeRow.getSizeInBytes % 8 == 0) + + holder.grow(ARRAY_MAX - holder.totalSize()) + assert(unsafeRow.getSizeInBytes % 8 == 0) - private def roundToWord(len: Int): Int = { - ByteArrayMethods.roundNumberOfBytesToNearestWord(len) + assert(intercept[IllegalArgumentException] { + holder.grow(ARRAY_MAX + 1 - holder.totalSize()) + }.getMessage.contains("because the size after growing")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala index c7c386b5b838a..4e0f903a030aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala @@ -23,17 +23,15 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow class BufferHolderSuite extends SparkFunSuite { test("SPARK-16071 Check the size limit to avoid integer overflow") { - var e = intercept[UnsupportedOperationException] { + assert(intercept[UnsupportedOperationException] { new BufferHolder(new UnsafeRow(Int.MaxValue / 8)) - } - assert(e.getMessage.contains("too many fields")) + }.getMessage.contains("too many fields")) val holder = new BufferHolder(new UnsafeRow(1000)) holder.reset() holder.grow(1000) - e = intercept[UnsupportedOperationException] { + assert(intercept[IllegalArgumentException] { holder.grow(Integer.MAX_VALUE) - } - assert(e.getMessage.contains("exceeds size limitation")) + }.getMessage.contains("exceeds size limitation")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala new file mode 100644 index 0000000000000..55569b6f2933e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class CodeBlockSuite extends SparkFunSuite { + + test("Block interpolates string and ExprValue inputs") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val stringLiteral = "false" + val code = code"boolean $isNull = $stringLiteral;" + assert(code.toString == "boolean expr1_isNull = false;") + } + + test("Literals are folded into string code parts instead of block inputs") { + val value = JavaCode.variable("expr1", IntegerType) + val intLiteral = 1 + val code = code"int $value = $intLiteral;" + assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) + } + + test("Block.stripMargin") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code1 = + code""" + |boolean $isNull = false; + |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + val expected = + s""" + |boolean expr1_isNull = false; + |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim + assert(code1.toString == expected) + + val code2 = + code""" + >boolean $isNull = false; + >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>') + assert(code2.toString == expected) + } + + test("Block can capture input expr values") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code = + code""" + |boolean $isNull = false; + |int $value = -1; + """.stripMargin + val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }.toSet + assert(exprValues.size == 2) + assert(exprValues === Set(value, isNull)) + } + + test("concatenate blocks") { + val isNull1 = JavaCode.isNullVariable("expr1_isNull") + val value1 = JavaCode.variable("expr1", IntegerType) + val isNull2 = JavaCode.isNullVariable("expr2_isNull") + val value2 = JavaCode.variable("expr2", IntegerType) + val literal = JavaCode.literal("100", IntegerType) + + val code = + code""" + |boolean $isNull1 = false; + |int $value1 = -1;""".stripMargin + + code""" + |boolean $isNull2 = true; + |int $value2 = $literal;""".stripMargin + + val expected = + """ + |boolean expr1_isNull = false; + |int expr1 = -1; + |boolean expr2_isNull = true; + |int expr2 = 100;""".stripMargin.trim + + assert(code.toString == expected) + + val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }).toSet + assert(exprValues.size == 5) + assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) + } + + test("Throws exception when interpolating unexcepted object in code block") { + val obj = Tuple2(1, 1) + val e = intercept[IllegalArgumentException] { + code"$obj" + } + assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) + } + + test("transform expr in code block") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val code = + code""" + |callFunc(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + + // We want to replace all occurrences of `expr` with the variable `aliasedParam`. 
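+ // The partial function given to `transformExprValues` below only matches the SimpleExprValue
+ // for "1 + 1"; other block inputs such as `expr1_isNull` and `expr1` are left untouched.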
+ val aliasedCode = code.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam + } + val expected = + code""" + |callFunc(int $aliasedParam) { + | boolean $isNull = false; + | int $exprInFunc = $aliasedParam + 1; + |}""".stripMargin + assert(aliasedCode.toString == expected.toString) + } + + test ("transform expr in nested blocks") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val funcs = Seq("callFunc1", "callFunc2", "callFunc3") + val subBlocks = funcs.map { funcName => + code""" + |$funcName(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + } + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + + val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}" + val transformedBlock = block.transform { + case b: Block => b.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam + } + }.asInstanceOf[CodeBlock] + + val expected1 = + code""" + |callFunc1(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected2 = + code""" + |callFunc2(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected3 = + code""" + |callFunc3(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val exprValues = transformedBlock.children.flatMap { block => + block.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + } + }.toSet + + assert(transformedBlock.children(0).toString == expected1.toString) + assert(transformedBlock.children(1).toString == expected2.toString) + assert(transformedBlock.children(2).toString == expected3.toString) + assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString) + assert(exprValues === Set(isNull, exprInFunc, aliasedParam)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index c4cde7091154b..0fec15bc42c17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite { assert(ret == "foo") } + test("embedFailure") { + import org.apache.commons.io.FileUtils + import java.io.File + val secretValue = String.valueOf(Math.random) + val tempFile = File.createTempFile("verifyembed", ".tmp") + tempFile.deleteOnExit() + val fname = tempFile.getAbsolutePath + + FileUtils.writeStringToFile(tempFile, secretValue) + + val xml = + s""" + | + |]> + |&embed; + """.stripMargin + val evaled = new UDFXPathUtil().evalString(xml, "/foo") + assert(evaled.isEmpty) + } + test("number eval") { var ret = util.evalNumber("truefalseb3c1-77", "a/c[2]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index bfa18a0919e45..c6f6d3abb860c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { // Test error message for invalid XML document val e1 = intercept[RuntimeException] { testExpr("/a>", "a", null.asInstanceOf[T]) } - assert(e1.getCause.getMessage.contains("Invalid XML document") && - e1.getCause.getMessage.contains("/a>")) + assert(e1.getCause.getCause.getMessage.contains( + "XML document structures must start and end within the same entity.")) + assert(e1.getMessage.contains("/a>")) // Test error message for invalid xpath val e2 = intercept[RuntimeException] { testExpr("", "!#$", null.asInstanceOf[T]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala index 21220b38968e8..788fedb3c8e8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala @@ -56,7 +56,7 @@ class CheckCartesianProductsSuite extends PlanTest { val thrownException = the [AnalysisException] thrownBy { performCartesianProductCheck(joinType) } - assert(thrownException.message.contains("Detected cartesian product")) + assert(thrownException.message.contains("Detected implicit cartesian product")) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 3f41f4b144096..8d7c9bf220bc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -141,6 +140,30 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, expected) } + test("Column pruning for ScriptTransformation") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = + ScriptTransformation( + Seq('a, 'b), + "func", + Seq.empty, + input, + null).analyze + val optimized = Optimize.execute(query) + + val expected = + ScriptTransformation( + Seq('a, 'b), + "func", + Seq.empty, + Project( + Seq('a, 'b), + input), + null).analyze + + comparePlans(optimized, expected) + } + test("Column pruning on Filter") { val input = LocalRelation('a.int, 'b.string, 'c.double) val plan1 = Filter('a > 1, input).analyze @@ -157,10 +180,10 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning on except/intersect/distinct") { val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = Project('a :: Nil, Except(input, input)).analyze + val query = Project('a :: Nil, Except(input, input, isAll = false)).analyze comparePlans(Optimize.execute(query), query) - val query2 = Project('a :: Nil, Intersect(input, input)).analyze + val query2 = Project('a :: Nil, Intersect(input, input, isAll = false)).analyze comparePlans(Optimize.execute(query2), query2) val query3 = Project('a :: Nil, 
Distinct(input)).analyze comparePlans(Optimize.execute(query3), query3) @@ -370,5 +393,13 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized2, expected2.analyze) } + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + val input = LocalRelation('key.int, 'value.string) + val query = input.select('key).where(rand(0L) > 0.5).where('key < 10).analyze + val optimized = Optimize.execute(query) + val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze + comparePlans(optimized, expected) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index 049a19b86f7cd..0c015f88e1e84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -52,4 +53,21 @@ class ConvertToLocalRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Filter on LocalRelation should be turned into a single LocalRelation") { + val testRelation = LocalRelation( + LocalRelation('a.int, 'b.int).output, + InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a1.int, 'b1.int).output, + InternalRow(1, 3) :: Nil) + + val filterAndProjectOnLocal = testRelation + .select(UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1")) + .where(LessThan(UnresolvedAttribute("b1"), Literal.create(6))) + + val optimized = Optimize.execute(filterAndProjectOnLocal.analyze) + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e068f51044589..e4671f0d1cce6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, - BooleanSimplification) :: Nil + BooleanSimplification, + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } 
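+ // Usage note: the SPARK-23405 left-semi test below reduces to a single call,
+ //   testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi),
+ // i.e. an IsNotNull('a) filter is expected to be inferred on both sides of the equi-join key.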
+ test("filter: filter out constraints in condition") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -196,13 +210,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftSemi, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftSemi, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) } test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { @@ -232,12 +240,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-21479: Outer join no filter push down to preserved side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze - val left = x - val right = y.where(IsNotNull('a) && 'a === 1) - val correctAnswer = left.join(right, LeftOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin( + x, y.where("a".attr === 1), + x, y.where(IsNotNull('a) && 'a === 1), + LeftOuter) + } + + test("SPARK-23564: left anti join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) + } + + test("SPARK-23564: left outer join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) + } + + test("SPARK-23564: right outer join should filter out null join keys on left side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..a36083b847043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -121,6 +121,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("OptimizedIn test: NULL IN (subquery) gets transformed to Filter(null)") { + val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a"))) + val originalQuery = + testRelation + .where(InSubquery(Seq(Literal.create(null, NullType)), subquery)) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: Inset optimization disabled as " + "list expression contains attribute)") { val originalQuery = @@ -176,6 +191,21 @@ class OptimizeInSuite extends PlanTest { } } + test("OptimizedIn test: one element in 
list gets transformed to EqualTo.") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -191,4 +221,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: In empty list gets transformed to `If` expression " + + "when value is nullable") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala index 7112c033eabce..36b083a540c3c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -47,7 +47,7 @@ class OptimizerExtendableSuite extends SparkFunSuite { DummyRule) :: Nil } - override def batches: Seq[Batch] = super.batches ++ myBatches + override def defaultBatches: Seq[Batch] = super.defaultBatches ++ myBatches } test("Extending batches possible") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala new file mode 100644 index 0000000000000..4fa4a7aadc8f2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_EXCLUDED_RULES + + +class OptimizerRuleExclusionSuite extends PlanTest { + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + private def verifyExcludedRules(optimizer: Optimizer, rulesToExclude: Seq[String]) { + val nonExcludableRules = optimizer.nonExcludableRules + + val excludedRuleNames = rulesToExclude.filter(!nonExcludableRules.contains(_)) + // Batches whose rules are all to be excluded should be removed as a whole. + val excludedBatchNames = optimizer.batches + .filter(batch => batch.rules.forall(rule => excludedRuleNames.contains(rule.ruleName))) + .map(_.name) + + withSQLConf( + OPTIMIZER_EXCLUDED_RULES.key -> excludedRuleNames.foldLeft("")((l, r) => l + "," + r)) { + val batches = optimizer.batches + // Verify removed batches. + assert(batches.forall(batch => !excludedBatchNames.contains(batch.name))) + // Verify removed rules. + assert( + batches + .forall(batch => batch.rules.forall(rule => !excludedRuleNames.contains(rule.ruleName)))) + // Verify non-excludable rules retained. + nonExcludableRules.foreach { nonExcludableRule => + assert( + optimizer.batches + .exists(batch => batch.rules.exists(rule => rule.ruleName == nonExcludableRule))) + } + } + } + + test("Exclude a single rule from multiple batches") { + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + PushPredicateThroughJoin.ruleName)) + } + + test("Exclude multiple rules from single or multiple batches") { + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + CombineUnions.ruleName, + RemoveLiteralFromGroupExpressions.ruleName, + RemoveRepetitionFromGroupExpressions.ruleName)) + } + + test("Exclude non-existent rule with other valid rules") { + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + LimitPushDown.ruleName, + InferFiltersFromConstraints.ruleName, + "DummyRuleName")) + } + + test("Try to exclude some non-excludable rules") { + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + ReplaceIntersectWithSemiJoin.ruleName, + PullupCorrelatedPredicates.ruleName, + RewriteCorrelatedScalarSubquery.ruleName, + RewritePredicateSubquery.ruleName, + RewriteExceptAll.ruleName, + RewriteIntersectAll.ruleName)) + } + + test("Custom optimizer") { + val optimizer = new SimpleTestOptimizer() { + override def defaultBatches: Seq[Batch] = + Batch("push", Once, + PushDownPredicate, + PushPredicateThroughJoin, + PushProjectionThroughUnion) :: + Batch("pull", Once, + PullupCorrelatedPredicates) :: Nil + + override def nonExcludableRules: Seq[String] = + PushDownPredicate.ruleName :: + PullupCorrelatedPredicates.ruleName :: Nil + } + + verifyExcludedRules( + optimizer, + Seq( + PushDownPredicate.ruleName, + PushProjectionThroughUnion.ruleName, + PullupCorrelatedPredicates.ruleName)) + } + + test("Verify optimized plan after excluding CombineUnions rule") { + val excludedRules = Seq( + ConvertToLocalRelation.ruleName, + PropagateEmptyRelation.ruleName, + CombineUnions.ruleName) + + withSQLConf( + OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) { + val optimizer = new SimpleTestOptimizer() + val originalQuery = testRelation.union(testRelation.union(testRelation)).analyze + val optimized = optimizer.execute(originalQuery) + 
comparePlans(originalQuery, optimized) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala index 6e183d81b7265..a22a81e9844d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala @@ -44,7 +44,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest { EmptyFunctionRegistry, new SQLConf())) { val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI) - override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches + override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches } test("check for invalid plan after execution of rule") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 169b8737d808b..8a5a55146726e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery} +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(In('a, Seq(ListQuery(correlatedSubquery)))) + .where(InSubquery(Seq('a), ListQuery(correlatedSubquery))) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala index 2319ab8046e56..dae5e6f3ee3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class RemoveRedundantSortsSuite extends PlanTest { @@ -42,15 +38,15 @@ class RemoveRedundantSortsSuite extends PlanTest { test("remove redundant order by") { val orderedPlan = 
testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) - val correctAnswer = orderedPlan.select('a).analyze + val correctAnswer = orderedPlan.limit(2).select('a).analyze comparePlans(Optimize.execute(optimized), correctAnswer) } test("do not remove sort if the order is different") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc) val optimized = Optimize.execute(reorderedDifferently.analyze) val correctAnswer = reorderedDifferently.analyze comparePlans(optimized, correctAnswer) @@ -72,6 +68,14 @@ class RemoveRedundantSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("different sorts are not simplified if limit is in between") { + val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) + .orderBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = orderedPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("range is already sorted") { val inputPlan = Range(1L, 1000L, 1, 10) val orderedPlan = inputPlan.orderBy('id.asc) @@ -98,4 +102,37 @@ class RemoveRedundantSortsSuite extends PlanTest { val correctAnswer = groupedAndResorted.analyze comparePlans(optimized, correctAnswer) } + + test("remove two consecutive sorts") { + val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc) + val optimized = Optimize.execute(orderedTwice.analyze) + val correctAnswer = testRelation.orderBy('b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("remove sorts separated by Filter/Project operators") { + val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc) + val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze) + val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze + comparePlans(optimizedWithProject, correctAnswerWithProject) + + val orderedTwiceWithFilter = + testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze) + val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithFilter, correctAnswerWithFilter) + + val orderedTwiceWithBoth = + testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze) + val correctAnswerWithBoth = + testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithBoth, correctAnswerWithBoth) + + val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc) + val optimizedThrice = Optimize.execute(orderedThrice.analyze) + val correctAnswerThrice = testRelation.select('b).where('b > Literal(0)) + .select(('b + 1).as('c)).orderBy('c.asc).analyze + comparePlans(optimizedThrice, correctAnswerThrice) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 52dc2e9fb076c..3b1b2d588ef67 
100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -42,7 +42,7 @@ class ReplaceOperatorSuite extends PlanTest { val table1 = LocalRelation('a.int, 'b.int) val table2 = LocalRelation('c.int, 'd.int) - val query = Intersect(table1, table2) + val query = Intersect(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -60,7 +60,7 @@ class ReplaceOperatorSuite extends PlanTest { val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) val table3 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -79,7 +79,7 @@ class ReplaceOperatorSuite extends PlanTest { val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) val table2 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) - val query = Except(table1, table2) + val query = Except(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -99,7 +99,7 @@ class ReplaceOperatorSuite extends PlanTest { val table3 = Project(Seq(attributeA, attributeB), Filter(attributeB < 1, Filter(attributeA >= 2, table1))) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -120,7 +120,7 @@ class ReplaceOperatorSuite extends PlanTest { val table3 = Project(Seq(attributeA, attributeB), Filter(attributeB < 1, Filter(attributeA >= 2, table1))) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -141,7 +141,7 @@ class ReplaceOperatorSuite extends PlanTest { Filter(attributeB < 1, Filter(attributeA >= 2, table1))) val table3 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -158,7 +158,7 @@ class ReplaceOperatorSuite extends PlanTest { val table1 = LocalRelation('a.int, 'b.int) val table2 = LocalRelation('c.int, 'd.int) - val query = Except(table1, table2) + val query = Except(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -173,7 +173,7 @@ class ReplaceOperatorSuite extends PlanTest { val left = table.where('b < 1).select('a).as("left") val right = table.where('b < 3).select('a).as("right") - val query = Except(left, right) + val query = Except(left, right, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index aa8841109329c..da3923f8d6477 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import 
org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.BooleanType class SetOperationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -144,4 +145,55 @@ class SetOperationSuite extends PlanTest { Distinct(Union(query3 :: query4 :: Nil))).analyze comparePlans(distinctUnionCorrectAnswer2, optimized2) } + + test("EXCEPT ALL rewrite") { + val input = Except(testRelation, testRelation2, isAll = true) + val rewrittenPlan = RewriteExceptAll(input) + + val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c) + .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f)) + .groupBy('a, 'b, 'c)('a, 'b, 'c, sum('vcol).as("sum")) + .where(GreaterThan('sum, Literal(0L))).analyze + val multiplerAttr = planFragment.output.last + val output = planFragment.output.dropRight(1) + val expectedPlan = Project(output, + Generate( + ReplicateRows(Seq(multiplerAttr) ++ output), + Nil, + false, + None, + output, + planFragment + )) + comparePlans(expectedPlan, rewrittenPlan) + } + + test("INTERSECT ALL rewrite") { + val input = Intersect(testRelation, testRelation2, isAll = true) + val rewrittenPlan = RewriteIntersectAll(input) + val leftRelation = testRelation + .select(Literal(true).as("vcol1"), Literal(null, BooleanType).as("vcol2"), 'a, 'b, 'c) + val rightRelation = testRelation2 + .select(Literal(null, BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f) + val planFragment = leftRelation.union(rightRelation) + .groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"), + count('vcol2).as("vcol2_count"), 'a, 'b, 'c) + .where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)), + GreaterThanOrEqual('vcol2_count, Literal(1L)))) + .select('a, 'b, 'c, + If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count")) + .analyze + val multiplerAttr = planFragment.output.last + val output = planFragment.output.dropRight(1) + val expectedPlan = Project(output, + Generate( + ReplicateRows(Seq(multiplerAttr) ++ output), + Nil, + false, + None, + output, + planFragment + )) + comparePlans(expectedPlan, rewrittenPlan) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index b597c8e162c83..8ad7c12020b82 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} @@ -29,7 +30,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType} class SimplifyConditionalSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil + val batches = 
Batch("SimplifyConditionals", FixedPoint(50), + BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { @@ -43,6 +45,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val unreachableBranch = (FalseLiteral, Literal(20)) private val nullBranch = (Literal.create(null, NullType), Literal(30)) + val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a"))) + val isNullCond = IsNull(UnresolvedAttribute("b")) + val notCond = Not(UnresolvedAttribute("c")) + test("simplify if") { assertEquivalent( If(TrueLiteral, Literal(10), Literal(20)), @@ -57,6 +63,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Literal(20)) } + test("remove unnecessary if when the outputs are semantic equivalence") { + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), + Subtract(Literal(10), Literal(1)), + Add(Literal(6), Literal(3))), + Literal(9)) + + // For non-deterministic condition, we don't remove the `If` statement. + assertEquivalent( + If(GreaterThan(Rand(0), Literal(0.5)), + Subtract(Literal(10), Literal(1)), + Add(Literal(6), Literal(3))), + If(GreaterThan(Rand(0), Literal(0.5)), + Literal(9), + Literal(9))) + } + test("remove unreachable branches") { // i.e. removing branches whose conditions are always false assertEquivalent( @@ -100,4 +123,47 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { None), CaseWhen(normalBranch :: trueBranch :: Nil, None)) } + + test("simplify CaseWhen if all the outputs are semantic equivalence") { + // When the conditions in `CaseWhen` are all deterministic, `CaseWhen` can be removed. + assertEquivalent( + CaseWhen((isNotNullCond, Subtract(Literal(3), Literal(2))) :: + (isNullCond, Literal(1)) :: + (notCond, Add(Literal(6), Literal(-5))) :: + Nil, + Add(Literal(2), Literal(-1))), + Literal(1) + ) + + // For non-deterministic conditions, we don't remove the `CaseWhen` statement. + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), Literal(2))) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (EqualTo(Rand(2), Literal(0.5)), Add(Literal(6), Literal(-5))) :: + Nil, + Add(Literal(2), Literal(-1))), + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (EqualTo(Rand(2), Literal(0.5)), Literal(1)) :: + Nil, + Literal(1)) + ) + + // When we have mixture of deterministic and non-deterministic conditions, we remove + // the deterministic conditions from the tail until a non-deterministic one is seen. 
+ assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), Literal(2))) :: + (NonFoldableLiteral(true), Add(Literal(2), Literal(-1))) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (NonFoldableLiteral(true), Add(Literal(6), Literal(-5))) :: + (NonFoldableLiteral(false), Literal(1)) :: + Nil, + Add(Literal(2), Literal(-1))), + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (NonFoldableLiteral(true), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + Nil, + Literal(1)) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 633d86d495581..5452e72b38647 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select('c as 'sCol2, 'a as 'sCol1) checkRule(originalQuery, correctAnswer) } + + test("SPARK-24313: support binary type as map keys in GetMapValue") { + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index f67697eb86c26..baaf01800b33b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -58,8 +58,5 @@ class ErrorParserSuite extends SparkFunSuite { intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", "^^^") - intercept("select * from r except all select * from t", 1, 0, - "EXCEPT ALL is not supported", - "^^^") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index cb8a1fecb80a7..781fc1e957ae0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,19 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - In('a, Seq(ListQuery(table("c").select('b))))) + InSubquery(Seq('a), ListQuery(table("c").select('b)))) + + assertEqual( + "(a, b, c) in (select d, e, f from g)", + InSubquery(Seq('a, 'b, 'c), ListQuery(table("g").select('d, 'e, 'f)))) + + assertEqual( + "(a, b) in (select c from d)", + InSubquery(Seq('a, 'b), ListQuery(table("d").select('c)))) + + assertEqual( + "(a) in (select b from c)", + InSubquery(Seq('a), ListQuery(table("c").select('b)))) } test("like 
expressions") { @@ -234,6 +246,11 @@ class ExpressionParserSuite extends PlanTest { intercept("foo(a x)", "extraneous input 'x'") } + test("lambda functions") { + assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr))) + assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr))) + } + test("window function expressions") { val func = 'foo.function(star()) def windowed( @@ -469,7 +486,7 @@ class ExpressionParserSuite extends PlanTest { Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) - intercept("1.20E-38BD", "DecimalType can only support precision up to 38") + intercept("1.20E-38BD", "decimal can only support precision up to 38") } test("strings") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 812bfdd7bb885..422bf97e30e7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType /** @@ -64,15 +65,16 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select * from a union select * from b", Distinct(a.union(b))) assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) assertEqual("select * from a union all select * from b", a.union(b)) - assertEqual("select * from a except select * from b", a.except(b)) - intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") - assertEqual("select * from a except distinct select * from b", a.except(b)) - assertEqual("select * from a minus select * from b", a.except(b)) - intercept("select * from a minus all select * from b", "MINUS ALL is not supported.") - assertEqual("select * from a minus distinct select * from b", a.except(b)) - assertEqual("select * from a intersect select * from b", a.intersect(b)) - intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") - assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + assertEqual("select * from a except select * from b", a.except(b, isAll = false)) + assertEqual("select * from a except distinct select * from b", a.except(b, isAll = false)) + assertEqual("select * from a except all select * from b", a.except(b, isAll = true)) + assertEqual("select * from a minus select * from b", a.except(b, isAll = false)) + assertEqual("select * from a minus all select * from b", a.except(b, isAll = true)) + assertEqual("select * from a minus distinct select * from b", a.except(b, isAll = false)) + assertEqual("select * from a " + + "intersect select * from b", a.intersect(b, isAll = false)) + assertEqual("select * from a intersect distinct select * from b", a.intersect(b, isAll = false)) + assertEqual("select * from a intersect all select * from b", a.intersect(b, isAll = true)) } test("common table expressions") { @@ -318,6 +320,16 @@ class PlanParserSuite extends AnalysisTest { 
assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", expected) + + intercept( + """select * + |from t + |lateral view explode(x) expl + |pivot ( + | sum(x) + | FOR y IN ('a', 'b') + |)""".stripMargin, + "LATERAL cannot be used together with PIVOT in FROM clause") } test("joins") { @@ -582,6 +594,33 @@ class PlanParserSuite extends AnalysisTest { parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), UnresolvedHint("MAPJOIN", Seq($"t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + + comparePlans( + parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"), + UnresolvedHint("COALESCE", Seq(Literal(10)), + table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ REPARTITION(100) */ * FROM t"), + UnresolvedHint("REPARTITION", Seq(Literal(100)), + table("t").select(star()))) + + comparePlans( + parsePlan( + "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), + InsertIntoTable(table("s"), Map.empty, + UnresolvedHint("REPARTITION", Seq(Literal(100)), + UnresolvedHint("COALESCE", Seq(Literal(500)), + UnresolvedHint("COALESCE", Seq(Literal(10)), + table("t").select(star())))), overwrite = false, ifPartitionNotExists = false)) + + comparePlans( + parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"), + UnresolvedHint("BROADCASTJOIN", Seq($"u"), + UnresolvedHint("REPARTITION", Seq(Literal(100)), + table("t").select(star())))) + + intercept("SELECT /*+ COALESCE(30 + 50) */ * FROM t", "mismatched input") } test("SPARK-20854: select hint syntax with expressions") { @@ -668,4 +707,50 @@ class PlanParserSuite extends AnalysisTest { OneRowRelation().select('rtrim.function("c&^,.", "bc...,,,&&&ccc")) ) } + + test("precedence of set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + val c = table("c").select(star()) + val d = table("d").select(star()) + + val query1 = + """ + |SELECT * FROM a + |UNION + |SELECT * FROM b + |EXCEPT + |SELECT * FROM c + |INTERSECT + |SELECT * FROM d + """.stripMargin + + val query2 = + """ + |SELECT * FROM a + |UNION + |SELECT * FROM b + |EXCEPT ALL + |SELECT * FROM c + |INTERSECT ALL + |SELECT * FROM d + """.stripMargin + + assertEqual(query1, Distinct(a.union(b)).except(c.intersect(d, isAll = false), isAll = false)) + assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) + + // Now disable precedence enforcement to verify the old behaviour. 
+ withSQLConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED.key -> "true") { + assertEqual(query1, + Distinct(a.union(b)).except(c, isAll = false).intersect(d, isAll = false)) + assertEqual(query2, Distinct(a.union(b)).except(c, isAll = true).intersect(d, isAll = true)) + } + + // Explicitly enable the precedence enforcement + withSQLConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED.key -> "false") { + assertEqual(query1, + Distinct(a.union(b)).except(c.intersect(d, isAll = false), isAll = false)) + assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index cc80a41df998d..ff0de0fb7c1f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -41,17 +41,17 @@ class TableIdentifierParserSuite extends SparkFunSuite { "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp", - "view", "while", "year", "work", "transaction", "write", "isolation", "level", - "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint", + "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot", + "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", - "insert", "int", "into", "is", "lateral", "like", "local", "none", "null", + "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing") + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a37e06d922642..5ad748b6113d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -187,7 +187,7 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(tr1 .where('a.attr > 10) - .intersect(tr2.where('b.attr < 100)) + 
.intersect(tr2.where('b.attr < 100), isAll = false) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100, @@ -200,7 +200,7 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { val tr2 = LocalRelation('a.int, 'b.int, 'c.int) verifyConstraints(tr1 .where('a.attr > 10) - .except(tr2.where('b.attr < 100)) + .except(tr2.where('b.attr < 100), isAll = false) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 14041747fd20e..aaab3ff1bf128 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown` plus analysis barrier - * and make sure it can correctly skip sub-trees that have already been analyzed. + * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown`. */ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -60,31 +59,6 @@ class LogicalPlanSuite extends SparkFunSuite { assert(invocationCount === 2) } - test("transformUp skips all ready resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) - plan transformUp function - - assert(invocationCount === 0) - - invocationCount = 0 - plan transformDown function - assert(invocationCount === 0) - } - - test("transformUp skips partially resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan1 = AnalysisBarrier(Project(Nil, testRelation)) - val plan2 = Project(Nil, plan1) - plan2 transformUp function - - assert(invocationCount === 1) - - invocationCount = 0 - plan2 transformDown function - assert(invocationCount === 1) - } - test("isStreaming") { val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val incrementalRelation = LocalRelation( @@ -101,4 +75,22 @@ class LogicalPlanSuite extends SparkFunSuite { assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) } + + test("transformExpressions works with a Stream") { + val id1 = NamedExpression.newExprId + val id2 = NamedExpression.newExprId + val plan = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(2), "b")(exprId = id2)), + OneRowRelation()) + val result = plan.transformExpressions { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(3), "b")(exprId = id2)), + OneRowRelation()) + assert(result.sameResult(expected)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6241d5cbb1d25..67740c3166471 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.plans +import org.scalactic.source import org.scalatest.Suite +import org.scalatest.Tag import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ @@ -33,6 +36,23 @@ import org.apache.spark.sql.internal.SQLConf */ trait PlanTest extends SparkFunSuite with PlanTestBase +trait CodegenInterpretedPlanTest extends PlanTest { + + override protected def test( + testName: String, + testTags: Tag*)(testFun: => Any)(implicit pos: source.Position): Unit = { + val codegenMode = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + val interpretedMode = CodegenObjectFactoryMode.NO_CODEGEN.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + super.test(testName + " (codegen path)", testTags: _*)(testFun)(pos) + } + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> interpretedMode) { + super.test(testName + " (interpreted path)", testTags: _*)(testFun)(pos) + } + } +} + /** * Provides helper methods for comparing plans, but without the overhead of * mandating a FunSuite. @@ -60,6 +80,8 @@ trait PlanTestBase extends PredicateHelper { self: Suite => Alias(a.child, a.name)(exprId = ExprId(0)) case ae: AggregateExpression => ae.copy(resultId = ExprId(0)) + case lv: NamedLambdaVariable => + lv.copy(exprId = ExprId(0), value = null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala new file mode 100644 index 0000000000000..9100e10ca0c09 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Literal, NamedExpression} + + +class AnalysisHelperSuite extends SparkFunSuite { + + private var invocationCount = 0 + private val function: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Project => + invocationCount += 1 + p + } + + private val exprFunction: PartialFunction[Expression, Expression] = { + case e: Literal => + invocationCount += 1 + Literal.TrueLiteral + } + + private def projectExprs: Seq[NamedExpression] = Alias(Literal.TrueLiteral, "A")() :: Nil + + test("setAnalyzed is recursive") { + val plan = Project(Nil, LocalRelation()) + plan.setAnalyzed() + assert(plan.find(!_.analyzed).isEmpty) + } + + test("resolveOperator runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.resolveOperators(function) + assert(invocationCount === 2) + } + + test("resolveOperatorsDown runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.resolveOperatorsDown(function) + assert(invocationCount === 2) + } + + test("resolveExpressions runs on operators recursively") { + invocationCount = 0 + val plan = Project(projectExprs, Project(projectExprs, LocalRelation())) + plan.resolveExpressions(exprFunction) + assert(invocationCount === 2) + } + + test("resolveOperator skips already resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.setAnalyzed() + plan.resolveOperators(function) + assert(invocationCount === 0) + } + + test("resolveOperatorsDown skips already resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.setAnalyzed() + plan.resolveOperatorsDown(function) + assert(invocationCount === 0) + } + + test("resolveExpressions skips already resolved plans") { + invocationCount = 0 + val plan = Project(projectExprs, Project(projectExprs, LocalRelation())) + plan.setAnalyzed() + plan.resolveExpressions(exprFunction) + assert(invocationCount === 0) + } + + test("resolveOperator skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, LocalRelation()) + val plan2 = Project(Nil, plan1) + plan1.setAnalyzed() + plan2.resolveOperators(function) + assert(invocationCount === 1) + } + + test("resolveOperatorsDown skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, LocalRelation()) + val plan2 = Project(Nil, plan1) + plan1.setAnalyzed() + plan2.resolveOperatorsDown(function) + assert(invocationCount === 1) + } + + test("resolveExpressions skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(projectExprs, LocalRelation()) + val plan2 = Project(projectExprs, plan1) + plan1.setAnalyzed() + plan2.resolveExpressions(exprFunction) + assert(invocationCount === 1) + } + + test("do not allow transform in analyzer") { + val plan = Project(Nil, LocalRelation()) + // These should be OK since we are not in the analyzer + plan.transform { case p: Project => p } + plan.transformUp { case p: Project => p } + plan.transformDown { case p: Project => p } + plan.transformAllExpressions { case lit: Literal => lit } + + // The following should fail in the analyzer scope + AnalysisHelper.markInAnalyzer { + intercept[RuntimeException] { plan.transform { case p: Project => p } } + intercept[RuntimeException] { plan.transformUp { case p:
Project => p } } + intercept[RuntimeException] { plan.transformDown { case p: Project => p } } + intercept[RuntimeException] { plan.transformAllExpressions { case lit: Literal => lit } } + } + } + + test("allow transform in resolveOperators in the analyzer") { + val plan = Project(Nil, LocalRelation()) + AnalysisHelper.markInAnalyzer { + plan.resolveOperators { case p: Project => p.transform { case p: Project => p } } + plan.resolveOperatorsDown { case p: Project => p.transform { case p: Project => p } } + plan.resolveExpressions { case lit: Literal => + Project(Nil, LocalRelation()).transform { case p: Project => p } + lit + } + } + } + + test("allow transform with allowInvokingTransformsInAnalyzer in the analyzer") { + val plan = Project(Nil, LocalRelation()) + AnalysisHelper.markInAnalyzer { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan.transform { case p: Project => p } + plan.transformUp { case p: Project => p } + plan.transformDown { case p: Project => p } + plan.transformAllExpressions { case lit: Literal => lit } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 43440d51dede6..47bfa62569583 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -357,6 +357,29 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 3) } + test("evaluateInSet with all zeros") { + validateEstimatedStats( + Filter(InSet(attrString, Set(3, 4, 5)), + StatsTestPlan(Seq(attrString), 0, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(0))), + expectedRowCount = 0) + } + + test("evaluateInSet with string") { + validateEstimatedStats( + Filter(InSet(attrString, Set("A0")), + StatsTestPlan(Seq(attrString), 10, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), + expectedRowCount = 1) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 84d0ba7bef642..b7092f4c42d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { @@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite { val right = JsonMethods.parse(rightJson) assert(left == right) } + + test("transform works on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the + // situation in which the TreeNode.mapChildren function's change detection is not triggered. A + // stream's first element is typically materialized, so in order to not trip the TreeNode change + // detection logic, we should not change the first element in the sequence. + val result = before.transform { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } + + test("withNewChildren on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + val result = before.withNewChildren(Stream(Literal(1), Literal(3))) + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala index 9d285916bcf42..229e32479082c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -104,4 +104,40 @@ class ComplexDataSuite extends SparkFunSuite { // The copied data should not be changed externally. 
assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") } + + test("SPARK-24659: GenericArrayData.equals should respect element type differences") { + import scala.reflect.ClassTag + + // Expected positive cases + def arraysShouldEqual[T: ClassTag](element: T*): Unit = { + val array1 = new GenericArrayData(Array[T](element: _*)) + val array2 = new GenericArrayData(Array[T](element: _*)) + assert(array1.equals(array2)) + } + arraysShouldEqual(true, false) // Boolean + arraysShouldEqual(0.toByte, 123.toByte, Byte.MinValue, Byte.MaxValue) // Byte + arraysShouldEqual(0.toShort, 123.toShort, Short.MinValue, Short.MaxValue) // Short + arraysShouldEqual(0, 123, -65536, Int.MinValue, Int.MaxValue) // Int + arraysShouldEqual(0L, 123L, -65536L, Long.MinValue, Long.MaxValue) // Long + arraysShouldEqual(0.0F, 123.0F, Float.MinValue, Float.MaxValue, Float.MinPositiveValue, + Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN) // Float + arraysShouldEqual(0.0, 123.0, Double.MinValue, Double.MaxValue, Double.MinPositiveValue, + Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN) // Double + arraysShouldEqual(Array[Byte](123.toByte), Array[Byte](), null) // SQL Binary + arraysShouldEqual(UTF8String.fromString("foo"), null) // SQL String + + // Expected negative cases + // Spark SQL considers cases like array<int> vs array<long> to be incompatible, + // so an underlying implementation of array type should return false in such cases. + def arraysShouldNotEqual[T: ClassTag, U: ClassTag](element1: T, element2: U): Unit = { + val array1 = new GenericArrayData(Array[T](element1)) + val array2 = new GenericArrayData(Array[U](element2)) + assert(!array1.equals(array2)) + } + arraysShouldNotEqual(true, 1) // Boolean <-> Int + arraysShouldNotEqual(123.toByte, 123) // Byte <-> Int + arraysShouldNotEqual(123.toByte, 123L) // Byte <-> Long + arraysShouldNotEqual(123.toShort, 123) // Short <-> Int + arraysShouldNotEqual(123, 123L) // Int <-> Long + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 625ff38943fa3..2423668392231 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -490,24 +490,36 @@ class DateTimeUtilsSuite extends SparkFunSuite { c1.set(1997, 1, 28, 10, 30, 0) val c2 = Calendar.getInstance() c2.set(1996, 9, 30, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) - c2.set(2000, 1, 28, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(2000, 1, 29, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(1996, 2, 31, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, true, c1.getTimeZone) === 3.94959677) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, false, c1.getTimeZone) + === 3.9495967741935485) + Seq(true, false).foreach { roundOff => + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis *
1000L, roundOff, c1.getTimeZone) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === 11) + } val c3 = Calendar.getInstance(TimeZonePST) c3.set(2000, 1, 28, 16, 0, 0) val c4 = Calendar.getInstance(TimeZonePST) c4.set(1997, 1, 28, 16, 0, 0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZonePST) === 36.0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZoneGMT) === 35.90322581) + assert( + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, false, TimeZoneGMT) + === 35.903225806451616) } test("from UTC timestamp") { @@ -650,18 +662,18 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(daysToMillis(16800, TimeZoneGMT) === c.getTimeInMillis) // There are some days are skipped entirely in some timezone, skip them here. - val skipped_days = Map[String, Int]( - "Kwajalein" -> 8632, - "Pacific/Apia" -> 15338, - "Pacific/Enderbury" -> 9131, - "Pacific/Fakaofo" -> 15338, - "Pacific/Kiritimati" -> 9131, - "Pacific/Kwajalein" -> 8632, - "MIT" -> 15338) + val skipped_days = Map[String, Set[Int]]( + "Kwajalein" -> Set(8632), + "Pacific/Apia" -> Set(15338), + "Pacific/Enderbury" -> Set(9130, 9131), + "Pacific/Fakaofo" -> Set(15338), + "Pacific/Kiritimati" -> Set(9130, 9131), + "Pacific/Kwajalein" -> Set(8632), + "MIT" -> Set(15338)) for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { - val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue) + val skipped = skipped_days.getOrElse(tz.getID, Set.empty) (-20000 to 20000).foreach { d => - if (d != skipped) { + if (!skipped.contains(d)) { assert(millisToDays(daysToMillis(d, tz), tz) === d, s"Round trip of ${d} did not work in tz ${tz}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 5a86f4055dce7..122a3125ee2c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -154,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite { left.merge(right) }.getMessage assert(message.equals("Failed to merge fields 'b' and 'b'. 
" + - "Failed to merge incompatible data types FloatType and LongType")) + "Failed to merge incompatible data types float and bigint")) } test("existsRecursively") { @@ -452,4 +452,30 @@ class DataTypeSuite extends SparkFunSuite { new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)), new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), false) + + test("SPARK-25031: MapType should produce current formatted string for complex types") { + val keyType: DataType = StructType(Seq( + StructField("a", DataTypes.IntegerType), + StructField("b", DataTypes.IntegerType))) + + val valueType: DataType = StructType(Seq( + StructField("c", DataTypes.IntegerType), + StructField("d", DataTypes.IntegerType))) + + val builder = new StringBuilder + + MapType(keyType, valueType).buildFormattedString(prefix = "", builder = builder) + + val result = builder.toString() + val expected = + """-- key: struct + | |-- a: integer (nullable = true) + | |-- b: integer (nullable = true) + |-- value: struct (valueContainsNull = true) + | |-- c: integer (nullable = true) + | |-- d: integer (nullable = true) + |""".stripMargin + + assert(result === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala new file mode 100644 index 0000000000000..d92f52f3248aa --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.Cast + +class DataTypeWriteCompatibilitySuite extends SparkFunSuite { + private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DateType, TimestampType, StringType, BinaryType) + + private val point2 = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))) + + private val widerPoint2 = StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType, nullable = false))) + + private val point3 = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false), + StructField("z", FloatType))) + + private val simpleContainerTypes = Seq( + ArrayType(LongType), ArrayType(LongType, containsNull = false), MapType(StringType, DoubleType), + MapType(StringType, DoubleType, valueContainsNull = false), point2, point3) + + private val nestedContainerTypes = Seq(ArrayType(point2, containsNull = false), + MapType(StringType, point3, valueContainsNull = false)) + + private val allNonNullTypes = Seq( + atomicTypes, simpleContainerTypes, nestedContainerTypes, Seq(CalendarIntervalType)).flatten + + test("Check NullType is incompatible with all other types") { + allNonNullTypes.foreach { t => + assertSingleError(NullType, t, "nulls", s"Should not allow writing None to type $t") { err => + assert(err.contains(s"incompatible with $t")) + } + } + } + + test("Check each type with itself") { + allNonNullTypes.foreach { t => + assertAllowed(t, t, "t", s"Should allow writing type to itself $t") + } + } + + test("Check atomic types: write allowed only when casting is safe") { + atomicTypes.foreach { w => + atomicTypes.foreach { r => + if (Cast.canSafeCast(w, r)) { + assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") + + } else { + assertSingleError(w, r, "t", + s"Should not allow writing $w to $r because cast is not safe") { err => + assert(err.contains("'t'"), "Should include the field name context") + assert(err.contains("Cannot safely cast"), "Should identify unsafe cast") + assert(err.contains(s"$w"), "Should include write type") + assert(err.contains(s"$r"), "Should include read type") + } + } + } + } + } + + test("Check struct types: missing required field") { + val missingRequiredField = StructType(Seq(StructField("x", FloatType, nullable = false))) + assertSingleError(missingRequiredField, point2, "t", + "Should fail because required field 'y' is missing") { err => + assert(err.contains("'t'"), "Should include the struct name for context") + assert(err.contains("'y'"), "Should include the nested field name") + assert(err.contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: missing starting field, matched by position") { + val missingRequiredField = StructType(Seq(StructField("y", FloatType, nullable = false))) + + // should have 2 errors: names x and y don't match, and field y is missing + assertNumErrors(missingRequiredField, point2, "t", + "Should fail because field 'x' is matched to field 'y' and required field 'y' is missing", 2) + { errs => + assert(errs(0).contains("'t'"), "Should include the struct name for context") + assert(errs(0).contains("expected 'x', found 'y'"), "Should detect name mismatch") + 
assert(errs(0).contains("field name does not match"), "Should identify name problem") + + assert(errs(1).contains("'t'"), "Should include the struct name for context") + assert(errs(1).contains("'y'"), "Should include the _last_ nested fields of the read schema") + assert(errs(1).contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: missing middle field, matched by position") { + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val expectedStruct = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false), + StructField("z", FloatType, nullable = true))) + + // types are compatible: (req int, req int) => (req int, req int, opt int) + // but this should still fail because the names do not match. + + assertNumErrors(missingMiddleField, expectedStruct, "t", + "Should fail because field 'y' is matched to field 'z'", 2) { errs => + assert(errs(0).contains("'t'"), "Should include the struct name for context") + assert(errs(0).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(0).contains("field name does not match"), "Should identify name problem") + + assert(errs(1).contains("'t'"), "Should include the struct name for context") + assert(errs(1).contains("'z'"), "Should include the nested field name") + assert(errs(1).contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: generic colN names are ignored") { + val missingMiddleField = StructType(Seq( + StructField("col1", FloatType, nullable = false), + StructField("col2", FloatType, nullable = false))) + + val expectedStruct = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))) + + // types are compatible: (req int, req int) => (req int, req int) + // names don't match, but match the naming convention used by Spark to fill in names + + assertAllowed(missingMiddleField, expectedStruct, "t", + "Should succeed because column names are ignored") + } + + test("Check struct types: required field is optional") { + val requiredFieldIsOptional = StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType, nullable = false))) + + assertSingleError(requiredFieldIsOptional, point2, "t", + "Should fail because required field 'x' is optional") { err => + assert(err.contains("'t.x'"), "Should include the nested field name context") + assert(err.contains("Cannot write nullable values to non-null field")) + } + } + + test("Check struct types: data field would be dropped") { + assertSingleError(point3, point2, "t", + "Should fail because field 'z' would be dropped") { err => + assert(err.contains("'t'"), "Should include the struct name for context") + assert(err.contains("'z'"), "Should include the extra field name") + assert(err.contains("Cannot write extra fields")) + } + } + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(widerPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check struct types: type promotion is allowed") { + assertAllowed(point2, widerPoint2, "t", + 
"Should allow widening float fields x and y to double") + } + + ignore("Check struct types: missing optional field is allowed") { + // built-in data sources do not yet support missing fields when optional + assertAllowed(point2, point3, "t", + "Should allow writing point (x,y) to point(x,y,z=null)") + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfLong, arrayOfInt, "arr", + "Should not allow array of longs to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check array types: type promotion is allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + assertAllowed(arrayOfInt, arrayOfLong, "arr", + "Should allow array of int written to array of long column") + } + + test("Check array types: cannot write optional to required elements") { + val arrayOfRequired = ArrayType(LongType, containsNull = false) + val arrayOfOptional = ArrayType(LongType) + + assertSingleError(arrayOfOptional, arrayOfRequired, "arr", + "Should not allow array of optional elements to array of required elements") { err => + assert(err.contains("'arr'"), "Should include type name context") + assert(err.contains("Cannot write nullable elements to array of non-nulls")) + } + } + + test("Check array types: writing required to optional elements is allowed") { + val arrayOfRequired = ArrayType(LongType, containsNull = false) + val arrayOfOptional = ArrayType(LongType) + + assertAllowed(arrayOfRequired, arrayOfOptional, "arr", + "Should allow array of required elements to array of optional elements") + } + + test("Check map value types: unsafe casts are not allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfLong, mapOfInt, "m", + "Should not allow map of longs to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map value types: type promotion is allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertAllowed(mapOfInt, mapOfLong, "m", "Should allow map of int written to map of long column") + } + + test("Check map value types: cannot write optional to required values") { + val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false) + val mapOfOptional = MapType(StringType, LongType) + + assertSingleError(mapOfOptional, mapOfRequired, "m", + "Should not allow map of optional values to map of required values") { err => + assert(err.contains("'m'"), "Should include type name context") + assert(err.contains("Cannot write nullable values to map of non-nulls")) + } + } + + test("Check map value types: writing required to optional values is allowed") { + val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false) + val mapOfOptional = MapType(StringType, LongType) + + assertAllowed(mapOfRequired, mapOfOptional, "m", + "Should allow map of required elements to map of optional elements") + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyLong, mapKeyInt, "m", + "Should not allow map 
of long keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: type promotion is allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertAllowed(mapKeyInt, mapKeyLong, "m", + "Should allow map of int written to map of long column") + } + + test("Check types with multiple errors") { + val readType = StructType(Seq( + StructField("a", ArrayType(DoubleType, containsNull = false)), + StructField("arr_of_structs", ArrayType(point2, containsNull = false)), + StructField("bad_nested_type", ArrayType(StringType)), + StructField("m", MapType(LongType, FloatType, valueContainsNull = false)), + StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)), + StructField("x", IntegerType, nullable = false), + StructField("missing1", StringType, nullable = false), + StructField("missing2", StringType) + )) + + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val writeType = StructType(Seq( + StructField("a", ArrayType(StringType)), + StructField("arr_of_structs", ArrayType(point3)), + StructField("bad_nested_type", point3), + StructField("m", MapType(DoubleType, DoubleType)), + StructField("map_of_structs", MapType(StringType, missingMiddleField)), + StructField("y", LongType) + )) + + assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => + assert(errs(0).contains("'top.a.element'"), "Should identify bad type") + assert(errs(0).contains("Cannot safely cast")) + assert(errs(0).contains("StringType to DoubleType")) + + assert(errs(1).contains("'top.a'"), "Should identify bad type") + assert(errs(1).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type") + assert(errs(2).contains("'z'"), "Should identify bad field") + assert(errs(2).contains("Cannot write extra fields to struct")) + + assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type") + assert(errs(3).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type") + assert(errs(4).contains("is incompatible with")) + + assert(errs(5).contains("'top.m.key'"), "Should identify bad type") + assert(errs(5).contains("Cannot safely cast")) + assert(errs(5).contains("DoubleType to LongType")) + + assert(errs(6).contains("'top.m.value'"), "Should identify bad type") + assert(errs(6).contains("Cannot safely cast")) + assert(errs(6).contains("DoubleType to FloatType")) + + assert(errs(7).contains("'top.m'"), "Should identify bad type") + assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(8).contains("field name does not match"), "Should identify name problem") + + assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(9).contains("'z'"), "Should identify missing field") + assert(errs(9).contains("missing fields"), "Should detect missing field") + + assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type") + 
assert(errs(10).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(11).contains("'top.x'"), "Should identify bad type") + assert(errs(11).contains("Cannot safely cast")) + assert(errs(11).contains("LongType to IntegerType")) + + assert(errs(12).contains("'top'"), "Should identify bad type") + assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") + assert(errs(12).contains("field name does not match"), "Should identify name problem") + + assert(errs(13).contains("'top'"), "Should identify bad type") + assert(errs(13).contains("'missing1'"), "Should identify missing field") + assert(errs(13).contains("missing fields"), "Should detect missing field") + } + } + + // Helper functions + + def assertAllowed(writeType: DataType, readType: DataType, name: String, desc: String): Unit = { + assert( + DataType.canWrite(writeType, readType, analysis.caseSensitiveResolution, name, + errMsg => fail(s"Should not produce errors but was called with: $errMsg")) === true, desc) + } + + def assertSingleError( + writeType: DataType, + readType: DataType, + name: String, + desc: String) + (errFunc: String => Unit): Unit = { + assertNumErrors(writeType, readType, name, desc, 1) { errs => + errFunc(errs.head) + } + } + + def assertNumErrors( + writeType: DataType, + readType: DataType, + name: String, + desc: String, + numErrs: Int) + (errFunc: Seq[String] => Unit): Unit = { + val errs = new mutable.ArrayBuffer[String]() + assert( + DataType.canWrite(writeType, readType, analysis.caseSensitiveResolution, name, + errMsg => errs += errMsg) === false, desc) + assert(errs.size === numErrs, s"Should produce $numErrs error messages") + errFunc(errs) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala new file mode 100644 index 0000000000000..210e65708170f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +import org.apache.spark.SparkFunSuite + +class MetadataSuite extends SparkFunSuite { + test("String Metadata") { + val meta = new MetadataBuilder().putString("key", "value").build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getString("key") === "value") + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getString("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Long Metadata") { + val meta = new MetadataBuilder().putLong("key", 12).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getLong("key") === 12) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getLong("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Double Metadata") { + val meta = new MetadataBuilder().putDouble("key", 12).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getDouble("key") === 12) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getDouble("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Boolean Metadata") { + val meta = new MetadataBuilder().putBoolean("key", true).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getBoolean("key") === true) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getBoolean("no_such_key")) + intercept[ClassCastException](meta.getString("key")) + } + + test("Null Metadata") { + val meta = new MetadataBuilder().putNull("key").build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getString("key") === null) + assert(meta.getDouble("key") === 0) + assert(meta.getLong("key") === 0) + assert(meta.getBoolean("key") === false) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getLong("no_such_key")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index c6ca8bb005429..53a78c94aa6fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.StructType.fromDDL class StructTypeSuite extends SparkFunSuite { @@ -37,4 +38,36 @@ class StructTypeSuite extends SparkFunSuite { val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage assert(e.contains("Available fields: a, b")) } + + test("SPARK-24849: toDDL - simple struct") { + val struct = StructType(Seq(StructField("a", IntegerType))) + + assert(struct.toDDL == "`a` INT") + } + + test("SPARK-24849: round trip toDDL - fromDDL") { + val struct = new StructType().add("a", IntegerType).add("b", StringType) + + assert(fromDDL(struct.toDDL) === struct) + } + + test("SPARK-24849: round trip fromDDL - toDDL") { + val struct = "`a` MAP<INT, STRING>,`b` INT" + + assert(fromDDL(struct).toDDL === struct) + } + + test("SPARK-24849: toDDL must take into account case of fields.") { + val struct = new StructType() + .add("metaData", new StructType().add("eventId", StringType)) + + assert(struct.toDDL == "`metaData` STRUCT<`eventId`: STRING>") + } + + test("SPARK-24849: toDDL should output field's comment") { + val struct = StructType(Seq( + StructField("b", BooleanType).withComment("Field's comment"))) + + assert(struct.toDDL == """`b` BOOLEAN
COMMENT 'Field\'s comment'""") + } } diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt new file mode 100644 index 0000000000000..2215ed91e2018 --- /dev/null +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -0,0 +1,704 @@ +================================================================================================ +Pushdown for many distinct value case +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8970 / 9122 1.8 570.3 1.0X +Parquet Vectorized (Pushdown) 471 / 491 33.4 30.0 19.0X +Native ORC Vectorized 7661 / 7853 2.1 487.0 1.2X +Native ORC Vectorized (Pushdown) 1134 / 1161 13.9 72.1 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9246 / 9297 1.7 587.8 1.0X +Parquet Vectorized (Pushdown) 480 / 488 32.8 30.5 19.3X +Native ORC Vectorized 7838 / 7850 2.0 498.3 1.2X +Native ORC Vectorized (Pushdown) 1054 / 1118 14.9 67.0 8.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8989 / 9100 1.7 571.5 1.0X +Parquet Vectorized (Pushdown) 448 / 467 35.1 28.5 20.1X +Native ORC Vectorized 7680 / 7768 2.0 488.3 1.2X +Native ORC Vectorized (Pushdown) 1067 / 1118 14.7 67.8 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9115 / 9266 1.7 579.5 1.0X +Parquet Vectorized (Pushdown) 466 / 492 33.7 29.7 19.5X +Native ORC Vectorized 7800 / 7914 2.0 495.9 1.2X +Native ORC Vectorized (Pushdown) 1075 / 1102 14.6 68.4 8.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9099 / 9237 1.7 578.5 1.0X +Parquet Vectorized (Pushdown) 462 / 475 34.1 29.3 19.7X +Native ORC Vectorized 7847 / 7925 2.0 498.9 1.2X +Native ORC Vectorized (Pushdown) 1078 / 1114 14.6 68.5 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 19303 / 19547 0.8 1227.3 1.0X +Parquet Vectorized 
(Pushdown) 19924 / 20089 0.8 1266.7 1.0X +Native ORC Vectorized 18725 / 19079 0.8 1190.5 1.0X +Native ORC Vectorized (Pushdown) 19310 / 19492 0.8 1227.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8117 / 8323 1.9 516.1 1.0X +Parquet Vectorized (Pushdown) 484 / 494 32.5 30.8 16.8X +Native ORC Vectorized 6811 / 7036 2.3 433.0 1.2X +Native ORC Vectorized (Pushdown) 1061 / 1082 14.8 67.5 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8105 / 8140 1.9 515.3 1.0X +Parquet Vectorized (Pushdown) 478 / 505 32.9 30.4 17.0X +Native ORC Vectorized 6914 / 7211 2.3 439.6 1.2X +Native ORC Vectorized (Pushdown) 1044 / 1064 15.1 66.4 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7983 / 8116 2.0 507.6 1.0X +Parquet Vectorized (Pushdown) 464 / 487 33.9 29.5 17.2X +Native ORC Vectorized 6703 / 6774 2.3 426.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7942 / 7983 2.0 504.9 1.0X +Parquet Vectorized (Pushdown) 468 / 479 33.6 29.7 17.0X +Native ORC Vectorized 6677 / 6779 2.4 424.5 1.2X +Native ORC Vectorized (Pushdown) 1021 / 1068 15.4 64.9 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7909 / 7958 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 485 / 494 32.4 30.8 16.3X +Native ORC Vectorized 6751 / 6846 2.3 429.2 1.2X +Native ORC Vectorized (Pushdown) 1043 / 1077 15.1 66.3 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8010 / 8033 2.0 509.2 1.0X +Parquet Vectorized (Pushdown) 472 / 489 33.3 30.0 17.0X +Native ORC Vectorized 6655 / 6808 2.4 423.1 1.2X +Native ORC Vectorized (Pushdown) 1015 / 1067 15.5 64.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------ +Parquet Vectorized 8983 / 9035 1.8 571.1 1.0X +Parquet Vectorized (Pushdown) 2204 / 2231 7.1 140.1 4.1X +Native ORC Vectorized 7864 / 8011 2.0 500.0 1.1X +Native ORC Vectorized (Pushdown) 2674 / 2789 5.9 170.0 3.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12723 / 12903 1.2 808.9 1.0X +Parquet Vectorized (Pushdown) 9112 / 9282 1.7 579.3 1.4X +Native ORC Vectorized 12090 / 12230 1.3 768.7 1.1X +Native ORC Vectorized (Pushdown) 9242 / 9372 1.7 587.6 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16453 / 16678 1.0 1046.1 1.0X +Parquet Vectorized (Pushdown) 15997 / 16262 1.0 1017.0 1.0X +Native ORC Vectorized 16652 / 17070 0.9 1058.7 1.0X +Native ORC Vectorized (Pushdown) 15843 / 16112 1.0 1007.2 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17098 / 17254 0.9 1087.1 1.0X +Parquet Vectorized (Pushdown) 17302 / 17529 0.9 1100.1 1.0X +Native ORC Vectorized 16790 / 17098 0.9 1067.5 1.0X +Native ORC Vectorized (Pushdown) 17329 / 17914 0.9 1101.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17088 / 17392 0.9 1086.4 1.0X +Parquet Vectorized (Pushdown) 17609 / 17863 0.9 1119.5 1.0X +Native ORC Vectorized 18334 / 69831 0.9 1165.7 0.9X +Native ORC Vectorized (Pushdown) 17465 / 17629 0.9 1110.4 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16903 / 17233 0.9 1074.6 1.0X +Parquet Vectorized (Pushdown) 16945 / 17032 0.9 1077.3 1.0X +Native ORC Vectorized 16377 / 16762 1.0 1041.2 1.0X +Native ORC Vectorized (Pushdown) 16950 / 17212 0.9 1077.7 1.0X + + +================================================================================================ +Pushdown for few distinct value case (use dictionary encoding) +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7245 
/ 7322 2.2 460.7 1.0X +Parquet Vectorized (Pushdown) 378 / 389 41.6 24.0 19.2X +Native ORC Vectorized 6720 / 6778 2.3 427.2 1.1X +Native ORC Vectorized (Pushdown) 1009 / 1032 15.6 64.2 7.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7627 / 7795 2.1 484.9 1.0X +Parquet Vectorized (Pushdown) 384 / 406 41.0 24.4 19.9X +Native ORC Vectorized 6724 / 7824 2.3 427.5 1.1X +Native ORC Vectorized (Pushdown) 968 / 986 16.3 61.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7157 / 7534 2.2 455.0 1.0X +Parquet Vectorized (Pushdown) 542 / 565 29.0 34.5 13.2X +Native ORC Vectorized 6716 / 7214 2.3 427.0 1.1X +Native ORC Vectorized (Pushdown) 1212 / 1288 13.0 77.0 5.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7368 / 7552 2.1 468.4 1.0X +Parquet Vectorized (Pushdown) 544 / 556 28.9 34.6 13.5X +Native ORC Vectorized 6740 / 6867 2.3 428.5 1.1X +Native ORC Vectorized (Pushdown) 1230 / 1426 12.8 78.2 6.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7427 / 7734 2.1 472.2 1.0X +Parquet Vectorized (Pushdown) 556 / 568 28.3 35.4 13.3X +Native ORC Vectorized 6847 / 7059 2.3 435.3 1.1X +Native ORC Vectorized (Pushdown) 1226 / 1230 12.8 77.9 6.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all distinct string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16998 / 17311 0.9 1080.7 1.0X +Parquet Vectorized (Pushdown) 16977 / 17250 0.9 1079.4 1.0X +Native ORC Vectorized 18447 / 19852 0.9 1172.8 0.9X +Native ORC Vectorized (Pushdown) 16614 / 17102 0.9 1056.3 1.0X + + +================================================================================================ +Pushdown benchmark for StringStartsWith +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '10%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9705 / 10814 1.6 617.0 1.0X +Parquet Vectorized (Pushdown) 3086 / 3574 5.1 196.2 3.1X +Native ORC Vectorized 10094 / 10695 1.6 
641.8 1.0X +Native ORC Vectorized (Pushdown) 9611 / 9999 1.6 611.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '1000%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8016 / 8183 2.0 509.7 1.0X +Parquet Vectorized (Pushdown) 444 / 457 35.4 28.2 18.0X +Native ORC Vectorized 6970 / 7169 2.3 443.2 1.2X +Native ORC Vectorized (Pushdown) 7447 / 7503 2.1 473.5 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '786432%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7908 / 8046 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 408 / 429 38.6 25.9 19.4X +Native ORC Vectorized 7021 / 7100 2.2 446.4 1.1X +Native ORC Vectorized (Pushdown) 7310 / 7490 2.2 464.8 1.1X + + +================================================================================================ +Pushdown benchmark for decimal +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4546 / 4743 3.5 289.0 1.0X +Parquet Vectorized (Pushdown) 161 / 175 98.0 10.2 28.3X +Native ORC Vectorized 5721 / 5842 2.7 363.7 0.8X +Native ORC Vectorized (Pushdown) 1019 / 1070 15.4 64.8 4.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6340 / 7236 2.5 403.1 1.0X +Parquet Vectorized (Pushdown) 3052 / 3164 5.2 194.1 2.1X +Native ORC Vectorized 8370 / 9214 1.9 532.1 0.8X +Native ORC Vectorized (Pushdown) 4137 / 4242 3.8 263.0 1.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12976 / 13249 1.2 825.0 1.0X +Parquet Vectorized (Pushdown) 12655 / 13570 1.2 804.6 1.0X +Native ORC Vectorized 15562 / 15950 1.0 989.4 0.8X +Native ORC Vectorized (Pushdown) 15042 / 15668 1.0 956.3 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 14303 / 14616 1.1 909.3 1.0X +Parquet Vectorized (Pushdown) 14380 / 14649 1.1 914.3 1.0X +Native ORC Vectorized 16964 / 17358 0.9 1078.5 0.8X +Native ORC Vectorized (Pushdown) 17255 / 17874 0.9 1097.0 0.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 
10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4701 / 6416 3.3 298.9 1.0X +Parquet Vectorized (Pushdown) 128 / 164 122.8 8.1 36.7X +Native ORC Vectorized 5698 / 7904 2.8 362.3 0.8X +Native ORC Vectorized (Pushdown) 913 / 942 17.2 58.0 5.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5376 / 5461 2.9 341.8 1.0X +Parquet Vectorized (Pushdown) 1479 / 1543 10.6 94.0 3.6X +Native ORC Vectorized 6640 / 6748 2.4 422.2 0.8X +Native ORC Vectorized (Pushdown) 2438 / 2479 6.5 155.0 2.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9224 / 9356 1.7 586.5 1.0X +Parquet Vectorized (Pushdown) 7172 / 7415 2.2 456.0 1.3X +Native ORC Vectorized 11017 / 11408 1.4 700.4 0.8X +Native ORC Vectorized (Pushdown) 8771 / 10218 1.8 557.7 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 13933 / 15990 1.1 885.8 1.0X +Parquet Vectorized (Pushdown) 12683 / 12942 1.2 806.4 1.1X +Native ORC Vectorized 16344 / 20196 1.0 1039.1 0.9X +Native ORC Vectorized (Pushdown) 15162 / 16627 1.0 964.0 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7102 / 8282 2.2 451.5 1.0X +Parquet Vectorized (Pushdown) 124 / 150 126.4 7.9 57.1X +Native ORC Vectorized 5811 / 6883 2.7 369.5 1.2X +Native ORC Vectorized (Pushdown) 1121 / 1502 14.0 71.3 6.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6894 / 7562 2.3 438.3 1.0X +Parquet Vectorized (Pushdown) 1863 / 1980 8.4 118.4 3.7X +Native ORC Vectorized 6812 / 6848 2.3 433.1 1.0X +Native ORC Vectorized (Pushdown) 2511 / 2598 6.3 159.7 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11732 / 12183 1.3 745.9 1.0X +Parquet Vectorized (Pushdown) 8912 / 9945 1.8 566.6 1.3X +Native ORC 
Vectorized 11499 / 12387 1.4 731.1 1.0X +Native ORC Vectorized (Pushdown) 9328 / 9382 1.7 593.1 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16272 / 16328 1.0 1034.6 1.0X +Parquet Vectorized (Pushdown) 15714 / 18100 1.0 999.1 1.0X +Native ORC Vectorized 16539 / 18897 1.0 1051.5 1.0X +Native ORC Vectorized (Pushdown) 16328 / 17306 1.0 1038.1 1.0X + + +================================================================================================ +Pushdown benchmark for InSet -> InFilters +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7993 / 8104 2.0 508.2 1.0X +Parquet Vectorized (Pushdown) 507 / 532 31.0 32.2 15.8X +Native ORC Vectorized 6922 / 7163 2.3 440.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7855 / 7963 2.0 499.4 1.0X +Parquet Vectorized (Pushdown) 503 / 516 31.3 32.0 15.6X +Native ORC Vectorized 6825 / 6954 2.3 433.9 1.2X +Native ORC Vectorized (Pushdown) 1019 / 1044 15.4 64.8 7.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7858 / 7928 2.0 499.6 1.0X +Parquet Vectorized (Pushdown) 490 / 519 32.1 31.1 16.0X +Native ORC Vectorized 7079 / 7966 2.2 450.1 1.1X +Native ORC Vectorized (Pushdown) 1276 / 1673 12.3 81.1 6.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8007 / 11155 2.0 509.0 1.0X +Parquet Vectorized (Pushdown) 519 / 540 30.3 33.0 15.4X +Native ORC Vectorized 6848 / 7072 2.3 435.4 1.2X +Native ORC Vectorized (Pushdown) 1026 / 1050 15.3 65.2 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7876 / 7956 2.0 500.7 1.0X +Parquet Vectorized (Pushdown) 521 / 535 30.2 33.1 15.1X +Native ORC Vectorized 7051 / 7368 2.2 448.3 1.1X +Native ORC Vectorized (Pushdown) 1014 / 1035 15.5 
64.5 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7897 / 8229 2.0 502.1 1.0X +Parquet Vectorized (Pushdown) 513 / 530 30.7 32.6 15.4X +Native ORC Vectorized 6730 / 6990 2.3 427.9 1.2X +Native ORC Vectorized (Pushdown) 1003 / 1036 15.7 63.8 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7967 / 8175 2.0 506.5 1.0X +Parquet Vectorized (Pushdown) 8155 / 8434 1.9 518.5 1.0X +Native ORC Vectorized 7002 / 7107 2.2 445.2 1.1X +Native ORC Vectorized (Pushdown) 1092 / 1139 14.4 69.4 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8032 / 8122 2.0 510.7 1.0X +Parquet Vectorized (Pushdown) 8141 / 8908 1.9 517.6 1.0X +Native ORC Vectorized 7140 / 7387 2.2 454.0 1.1X +Native ORC Vectorized (Pushdown) 1156 / 1220 13.6 73.5 6.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8088 / 8350 1.9 514.2 1.0X +Parquet Vectorized (Pushdown) 8629 / 8702 1.8 548.6 0.9X +Native ORC Vectorized 7480 / 7886 2.1 475.6 1.1X +Native ORC Vectorized (Pushdown) 1106 / 1145 14.2 70.3 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8028 / 8165 2.0 510.4 1.0X +Parquet Vectorized (Pushdown) 8349 / 8674 1.9 530.8 1.0X +Native ORC Vectorized 7107 / 7354 2.2 451.8 1.1X +Native ORC Vectorized (Pushdown) 1175 / 1207 13.4 74.7 6.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8041 / 8195 2.0 511.2 1.0X +Parquet Vectorized (Pushdown) 8466 / 8604 1.9 538.2 0.9X +Native ORC Vectorized 7116 / 7286 2.2 452.4 1.1X +Native ORC Vectorized (Pushdown) 1197 / 1214 13.1 76.1 6.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ 
+Parquet Vectorized 7998 / 8311 2.0 508.5 1.0X +Parquet Vectorized (Pushdown) 9366 / 11257 1.7 595.5 0.9X +Native ORC Vectorized 7856 / 9273 2.0 499.5 1.0X +Native ORC Vectorized (Pushdown) 1350 / 1747 11.7 85.8 5.9X + + +================================================================================================ +Pushdown benchmark for tinyint +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3461 / 3997 4.5 220.1 1.0X +Parquet Vectorized (Pushdown) 270 / 315 58.4 17.1 12.8X +Native ORC Vectorized 4107 / 5372 3.8 261.1 0.8X +Native ORC Vectorized (Pushdown) 778 / 1553 20.2 49.5 4.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4771 / 6655 3.3 303.3 1.0X +Parquet Vectorized (Pushdown) 1322 / 1606 11.9 84.0 3.6X +Native ORC Vectorized 4437 / 4572 3.5 282.1 1.1X +Native ORC Vectorized (Pushdown) 1781 / 1976 8.8 113.2 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7433 / 7752 2.1 472.6 1.0X +Parquet Vectorized (Pushdown) 5863 / 5913 2.7 372.8 1.3X +Native ORC Vectorized 7986 / 8084 2.0 507.7 0.9X +Native ORC Vectorized (Pushdown) 6522 / 6608 2.4 414.6 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11190 / 11519 1.4 711.4 1.0X +Parquet Vectorized (Pushdown) 10861 / 11206 1.4 690.5 1.0X +Native ORC Vectorized 11622 / 12196 1.4 738.9 1.0X +Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X + + +================================================================================================ +Pushdown benchmark for Timestamp +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as INT96 row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4784 / 4956 3.3 304.2 1.0X +Parquet Vectorized (Pushdown) 4838 / 4917 3.3 307.6 1.0X +Native ORC Vectorized 3923 / 4173 4.0 249.4 1.2X +Native ORC Vectorized (Pushdown) 894 / 943 17.6 56.8 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as INT96 rows (value < 
CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5686 / 5901 2.8 361.5 1.0X +Parquet Vectorized (Pushdown) 5555 / 5895 2.8 353.2 1.0X +Native ORC Vectorized 4844 / 4957 3.2 308.0 1.2X +Native ORC Vectorized (Pushdown) 2141 / 2230 7.3 136.1 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as INT96 rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9100 / 9421 1.7 578.6 1.0X +Parquet Vectorized (Pushdown) 9122 / 9496 1.7 580.0 1.0X +Native ORC Vectorized 8365 / 8874 1.9 531.9 1.1X +Native ORC Vectorized (Pushdown) 7128 / 7376 2.2 453.2 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as INT96 rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12764 / 13120 1.2 811.5 1.0X +Parquet Vectorized (Pushdown) 12656 / 13003 1.2 804.7 1.0X +Native ORC Vectorized 13096 / 13233 1.2 832.6 1.0X +Native ORC Vectorized (Pushdown) 12710 / 15611 1.2 808.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MICROS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4381 / 4796 3.6 278.5 1.0X +Parquet Vectorized (Pushdown) 122 / 137 129.3 7.7 36.0X +Native ORC Vectorized 3913 / 3988 4.0 248.8 1.1X +Native ORC Vectorized (Pushdown) 905 / 945 17.4 57.6 4.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5145 / 5184 3.1 327.1 1.0X +Parquet Vectorized (Pushdown) 1426 / 1519 11.0 90.7 3.6X +Native ORC Vectorized 4827 / 4901 3.3 306.9 1.1X +Native ORC Vectorized (Pushdown) 2133 / 2210 7.4 135.6 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9234 / 9516 1.7 587.1 1.0X +Parquet Vectorized (Pushdown) 6752 / 7046 2.3 429.3 1.4X +Native ORC Vectorized 8418 / 8998 1.9 535.2 1.1X +Native ORC Vectorized (Pushdown) 7199 / 7314 2.2 457.7 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------ +Parquet Vectorized 12414 / 12458 1.3 789.2 1.0X +Parquet Vectorized (Pushdown) 12094 / 12249 1.3 768.9 1.0X +Native ORC Vectorized 12198 / 13755 1.3 775.5 1.0X +Native ORC Vectorized (Pushdown) 12205 / 12431 1.3 776.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MILLIS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4369 / 4515 3.6 277.8 1.0X +Parquet Vectorized (Pushdown) 116 / 125 136.2 7.3 37.8X +Native ORC Vectorized 3965 / 4703 4.0 252.1 1.1X +Native ORC Vectorized (Pushdown) 892 / 1162 17.6 56.7 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5211 / 5409 3.0 331.3 1.0X +Parquet Vectorized (Pushdown) 1427 / 1438 11.0 90.7 3.7X +Native ORC Vectorized 4719 / 4883 3.3 300.1 1.1X +Native ORC Vectorized (Pushdown) 2191 / 2228 7.2 139.3 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8716 / 8953 1.8 554.2 1.0X +Parquet Vectorized (Pushdown) 6632 / 6968 2.4 421.7 1.3X +Native ORC Vectorized 8376 / 9118 1.9 532.5 1.0X +Native ORC Vectorized (Pushdown) 7218 / 7609 2.2 458.9 1.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12264 / 12452 1.3 779.7 1.0X +Parquet Vectorized (Pushdown) 11766 / 11927 1.3 748.0 1.0X +Native ORC Vectorized 12101 / 12301 1.3 769.3 1.0X +Native ORC Vectorized (Pushdown) 11983 / 12651 1.3 761.9 1.0X + diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ef41837f89d68..ba17f5f33f2b6 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.5.9 + 2.7.3 jar @@ -118,7 +118,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.scalacheck @@ -146,19 +146,6 @@ parquet-avro test - - - org.apache.avro - avro - 1.8.1 - test - org.mockito mockito-core diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c7c4c7b3e7715..c8cf44b51df77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import 
org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param taskContext the current task context. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. */ @@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, - TaskMemoryManager taskMemoryManager, + TaskContext taskContext, int initialCapacity, long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); + this.map = new BytesToBytesMap( + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks when the downstream operator + // does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
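The addTaskCompletionListener call that follows is the general pattern for tying a task-scoped resource to the task's lifetime: even if a downstream operator (such as a limit) stops consuming early, the listener still runs when the task finishes. A minimal sketch of the pattern, assuming a hypothetical NativeBuffer resource with a free() method (illustrative only, not the Spark class itself):

    import org.apache.spark.TaskContext;

    // Hypothetical task-scoped resource; only the cleanup wiring matters here.
    class NativeBuffer {
      private boolean freed = false;

      NativeBuffer(TaskContext taskContext) {
        // Runs at task completion even if the consumer never drains our output.
        taskContext.addTaskCompletionListener(context -> free());
      }

      void free() {
        if (!freed) {
          freed = true;
          // release pages / off-heap memory here
        }
      }
    }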
+ taskContext.addTaskCompletionListener(context -> { + free(); + }); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 12f4d658b1868..9bfad1e83ee7b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -136,7 +136,7 @@ public int getInt(int rowId) { public long getLong(int rowId) { int index = getRowIndex(rowId); if (isTimestamp) { - return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000 % 1000; } else { return longData.vector[index]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index dcebdc39f0aa2..a0d9578a377b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -497,7 +497,7 @@ private void putValues( * Returns the number of micros since epoch from an element of TimestampColumnVector. */ private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { - return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + return vector.time[index] * 1000 + (vector.nanos[index] / 1000 % 1000); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index e65cd252c3ddf..c975e52734e01 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet; -import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; @@ -147,7 +146,8 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); this.reader = new ParquetFileReader( configuration, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } @@ -225,7 +225,8 @@ protected void initialize(String path, List columns) throws IOException this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); this.reader = new ParquetFileReader( config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } } @@ -293,7 +294,7 @@ protected static IntIterator createRLEIterator( return new 
RLEIntIterator( new RunLengthBitPackingHybridDecoder( BytesUtils.getWidthFromMaxInt(maxLevel), - new ByteArrayInputStream(bytes.toByteArray()))); + bytes.toInputStream())); } catch (IOException e) { throw new IOException("could not read levels in page for col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 72f1d024b08ce..ba26b57567e64 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.TimeZone; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -165,6 +167,8 @@ void readBatch(int total, WritableColumnVector column) throws IOException { leftInPage = (int) (endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); + PrimitiveType.PrimitiveTypeName typeName = + descriptor.getPrimitiveType().getPrimitiveTypeName(); if (isCurrentPageDictionaryEncoded) { // Read and decode dictionary ids. defColumn.readIntegers( @@ -173,12 +177,12 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. if (column.hasDictionary() || (rowId == 0 && - (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || - (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && + (typeName == PrimitiveType.PrimitiveTypeName.INT32 || + (typeName == PrimitiveType.PrimitiveTypeName.INT64 && originalType != OriginalType.TIMESTAMP_MILLIS) || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { + typeName == PrimitiveType.PrimitiveTypeName.FLOAT || + typeName == PrimitiveType.PrimitiveTypeName.DOUBLE || + typeName == PrimitiveType.PrimitiveTypeName.BINARY))) { // Column vector supports lazy decoding of dictionary values so just set the dictionary. // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some // non-dictionary encoded values have already been added). 
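The ORC timestamp change earlier in this patch (OrcColumnVector.java and OrcColumnarBatchReader.java) deserves a worked example. Following java.sql.Timestamp semantics, the column vector's time field already holds whole milliseconds since the epoch, while nanos repeats the full sub-second fraction; adding nanos / 1000 therefore double-counts the millisecond part, and only the microseconds within the current millisecond should be added back. A standalone sketch of the two conversions (illustrative, not Spark code):

    public class TimestampMicrosExample {
      // Previous conversion: adds all microseconds-within-second on top of the millis.
      static long oldMicros(long timeMillis, int nanos) {
        return timeMillis * 1000 + nanos / 1000;
      }

      // Fixed conversion: adds only the microseconds within the current millisecond.
      static long newMicros(long timeMillis, int nanos) {
        return timeMillis * 1000 + nanos / 1000 % 1000;
      }

      public static void main(String[] args) {
        long timeMillis = 1000123L;  // 1000.123 seconds since the epoch
        int nanos = 123456789;       // .123456789 s fraction, which repeats the 123 ms

        System.out.println(oldMicros(timeMillis, nanos));  // 1000246456 -> off by 123 ms
        System.out.println(newMicros(timeMillis, nanos));  // 1000123456 -> 1000.123456 s
      }
    }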
@@ -193,7 +197,7 @@ void readBatch(int total, WritableColumnVector column) throws IOException { decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); } column.setDictionary(null); - switch (descriptor.getType()) { + switch (typeName) { case BOOLEAN: readBooleanBatch(rowId, num, column); break; @@ -216,10 +220,11 @@ void readBatch(int total, WritableColumnVector column) throws IOException { readBinaryBatch(rowId, num, column); break; case FIXED_LEN_BYTE_ARRAY: - readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength()); + readFixedLenByteArrayBatch( + rowId, num, column, descriptor.getPrimitiveType().getTypeLength()); break; default: - throw new IOException("Unsupported type: " + descriptor.getType()); + throw new IOException("Unsupported type: " + typeName); } } @@ -241,8 +246,8 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc WritableColumnVector column) { return new SchemaColumnConvertNotSupportedException( Arrays.toString(descriptor.getPath()), - descriptor.getType().toString(), - column.dataType().toString()); + descriptor.getPrimitiveType().getPrimitiveTypeName().toString(), + column.dataType().catalogString()); } /** @@ -253,7 +258,7 @@ private void decodeDictionaryIds( int num, WritableColumnVector column, WritableColumnVector dictionaryIds) { - switch (descriptor.getType()) { + switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) { case INT32: if (column.dataType() == DataTypes.IntegerType || DecimalType.is32BitDecimalType(column.dataType())) { @@ -379,7 +384,8 @@ private void decodeDictionaryIds( break; default: - throw new UnsupportedOperationException("Unsupported type: " + descriptor.getType()); + throw new UnsupportedOperationException( + "Unsupported type: " + descriptor.getPrimitiveType().getPrimitiveTypeName()); } } @@ -388,7 +394,8 @@ private void decodeDictionaryIds( * is guaranteed that num is smaller than the number of values left in the current page. */ - private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { + private void readBooleanBatch(int rowId, int num, WritableColumnVector column) + throws IOException { if (column.dataType() != DataTypes.BooleanType) { throw constructConvertNotSupportedException(descriptor, column); } @@ -396,7 +403,7 @@ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } - private void readIntBatch(int rowId, int num, WritableColumnVector column) { + private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || @@ -414,7 +421,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { } } - private void readLongBatch(int rowId, int num, WritableColumnVector column) { + private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. 
if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType()) || @@ -434,7 +441,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } - private void readFloatBatch(int rowId, int num, WritableColumnVector column) { + private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: support implicit cast to double? if (column.dataType() == DataTypes.FloatType) { @@ -445,7 +452,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { } } - private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { + private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.DoubleType) { @@ -456,7 +463,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { } } - private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { + private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; @@ -556,7 +563,7 @@ public Void visit(DataPageV2 dataPageV2) { }); } - private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { + private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) throws IOException { this.endOfPageValueCount = valuesRead + pageValueCount; if (dataEncoding.usesDictionary()) { this.dataColumn = null; @@ -581,7 +588,7 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } try { - dataColumn.initFromPage(pageValueCount, bytes, offset); + dataColumn.initFromPage(pageValueCount, in); } catch (IOException e) { throw new IOException("could not read page in col " + descriptor, e); } @@ -602,12 +609,11 @@ private void readPageV1(DataPageV1 page) throws IOException { this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { - byte[] bytes = page.getBytes().toByteArray(); - rlReader.initFromPage(pageValueCount, bytes, 0); - int next = rlReader.getNextOffset(); - dlReader.initFromPage(pageValueCount, bytes, next); - next = dlReader.getNextOffset(); - initDataReader(page.getValueEncoding(), bytes, next); + BytesInput bytes = page.getBytes(); + ByteBufferInputStream in = bytes.toInputStream(); + rlReader.initFromPage(pageValueCount, in); + dlReader.initFromPage(pageValueCount, in); + initDataReader(page.getValueEncoding(), in); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } @@ -619,12 +625,13 @@ private void readPageV2(DataPageV2 page) throws IOException { page.getRepetitionLevels(), descriptor); int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); - this.defColumn = new VectorizedRleValuesReader(bitWidth); + // do not read the length from the stream. v2 pages handle dividing the page bytes. 
+ this.defColumn = new VectorizedRleValuesReader(bitWidth, false); this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); - this.defColumn.initFromBuffer( - this.pageValueCount, page.getDefinitionLevels().toByteArray()); + this.defColumn.initFromPage( + this.pageValueCount, page.getDefinitionLevels().toInputStream()); try { - initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); + initDataReader(page.getDataEncoding(), page.getData().toInputStream()); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 5934a23db8af1..f02861355c404 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -270,21 +270,23 @@ public boolean nextBatch() throws IOException { private void initializeInternal() throws IOException, UnsupportedOperationException { // Check that the requested schema is supported. missingColumns = new boolean[requestedSchema.getFieldCount()]; + List columns = requestedSchema.getColumns(); + List paths = requestedSchema.getPaths(); for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { Type t = requestedSchema.getFields().get(i); if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { throw new UnsupportedOperationException("Complex types not supported."); } - String[] colPath = requestedSchema.getPaths().get(i); + String[] colPath = paths.get(i); if (fileSchema.containsPath(colPath)) { ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); - if (!fd.equals(requestedSchema.getColumns().get(i))) { + if (!fd.equals(columns.get(i))) { throw new UnsupportedOperationException("Schema evolution not supported."); } missingColumns[i] = false; } else { - if (requestedSchema.getColumns().get(i).getMaxDefinitionLevel() == 0) { + if (columns.get(i).getMaxDefinitionLevel() == 0) { // Column is missing in data but the required data is non-nullable. This file is invalid. throw new IOException("Required column is missing in data file. Col: " + Arrays.toString(colPath)); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 5b75f719339fb..c62dc3d86386e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -20,8 +20,9 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.io.ParquetDecodingException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import org.apache.spark.unsafe.Platform; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; @@ -30,24 +31,18 @@ * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. 
*/ public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { - private byte[] buffer; - private int offset; - private int bitOffset; // Only used for booleans. - private ByteBuffer byteBuffer; // used to wrap the byte array buffer + private ByteBufferInputStream in = null; - private static final boolean bigEndianPlatform = - ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + // Only used for booleans. + private int bitOffset; + private byte currentByte = 0; public VectorizedPlainValuesReader() { } @Override - public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException { - this.buffer = bytes; - this.offset = offset + Platform.BYTE_ARRAY_OFFSET; - if (bigEndianPlatform) { - byteBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); - } + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; } @Override @@ -63,115 +58,157 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) { } } + private ByteBuffer getBuffer(int length) { + try { + return in.slice(length).order(ByteOrder.LITTLE_ENDIAN); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read " + length + " bytes", e); + } + } + @Override public final void readIntegers(int total, WritableColumnVector c, int rowId) { - c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putIntsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putInt(rowId + i, buffer.getInt()); + } + } } @Override public final void readLongs(int total, WritableColumnVector c, int rowId) { - c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putLongsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putLong(rowId + i, buffer.getLong()); + } + } } @Override public final void readFloats(int total, WritableColumnVector c, int rowId) { - c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putFloats(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putFloat(rowId + i, buffer.getFloat()); + } + } } @Override public final void readDoubles(int total, WritableColumnVector c, int rowId) { - c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putDoubles(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putDouble(rowId + i, buffer.getDouble()); + } + } } @Override public final void readBytes(int total, WritableColumnVector c, int rowId) { - for (int i = 0; i < total; i++) { - // Bytes are stored as a 4-byte little endian int. Just read the first byte. 
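// [Illustrative aside, not part of the patch] The new VectorizedPlainValuesReader methods above all
// follow one pattern: slice a little-endian ByteBuffer out of the page's ByteBufferInputStream,
// bulk-copy from the backing array when hasArray() is true, and fall back to per-element gets for
// direct or read-only buffers. A minimal self-contained sketch of that idea using only java.nio —
// the IntSink interface and the readIntsLittleEndian name are hypothetical stand-ins, not Spark APIs:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class LittleEndianReadSketch {
  // Hypothetical sink standing in for WritableColumnVector in this sketch.
  interface IntSink {
    void putInts(int rowId, int count, byte[] src, int srcOffset);
    void putInt(int rowId, int value);
  }

  static void readIntsLittleEndian(int total, IntSink sink, int rowId, ByteBuffer page) {
    // Take just the bytes this batch needs and fix the byte order once.
    ByteBuffer buffer = page.duplicate().order(ByteOrder.LITTLE_ENDIAN);
    buffer.limit(buffer.position() + total * 4);

    if (buffer.hasArray()) {
      // Fast path: hand the heap backing array plus offset to a bulk copy.
      int offset = buffer.arrayOffset() + buffer.position();
      sink.putInts(rowId, total, buffer.array(), offset);
    } else {
      // Direct (off-heap) or read-only buffer: decode one value at a time.
      for (int i = 0; i < total; i += 1) {
        sink.putInt(rowId + i, buffer.getInt());
      }
    }
    // Advance the original buffer past the consumed bytes.
    page.position(page.position() + total * 4);
  }
}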
- // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. - c.putByte(rowId + i, Platform.getByte(buffer, offset)); - offset += 4; + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + for (int i = 0; i < total; i += 1) { + c.putByte(rowId + i, buffer.get()); + // skip the next 3 bytes + buffer.position(buffer.position() + 3); } } @Override public final boolean readBoolean() { - byte b = Platform.getByte(buffer, offset); - boolean v = (b & (1 << bitOffset)) != 0; + // TODO: vectorize decoding and keep boolean[] instead of currentByte + if (bitOffset == 0) { + try { + currentByte = (byte) in.read(); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read a byte", e); + } + } + + boolean v = (currentByte & (1 << bitOffset)) != 0; bitOffset += 1; if (bitOffset == 8) { bitOffset = 0; - offset++; } return v; } @Override public final int readInteger() { - int v = Platform.getInt(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Integer.reverseBytes(v); - } - offset += 4; - return v; + return getBuffer(4).getInt(); } @Override public final long readLong() { - long v = Platform.getLong(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Long.reverseBytes(v); - } - offset += 8; - return v; + return getBuffer(8).getLong(); } @Override public final byte readByte() { - return (byte)readInteger(); + return (byte) readInteger(); } @Override public final float readFloat() { - float v; - if (!bigEndianPlatform) { - v = Platform.getFloat(buffer, offset); - } else { - v = byteBuffer.getFloat(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 4; - return v; + return getBuffer(4).getFloat(); } @Override public final double readDouble() { - double v; - if (!bigEndianPlatform) { - v = Platform.getDouble(buffer, offset); - } else { - v = byteBuffer.getDouble(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 8; - return v; + return getBuffer(8).getDouble(); } @Override public final void readBinary(int total, WritableColumnVector v, int rowId) { for (int i = 0; i < total; i++) { int len = readInteger(); - int start = offset; - offset += len; - v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + v.putByteArray(rowId + i, buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + v.putByteArray(rowId + i, bytes); + } } } @Override public final Binary readBinary(int len) { - Binary result = Binary.fromConstantByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len); - offset += len; - return result; + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + return Binary.fromConstantByteArray( + buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + return Binary.fromConstantByteArray(bytes); + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index fc7fa70c39419..fe3d31ae8e746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import org.apache.parquet.Preconditions; +import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; @@ -27,6 +28,9 @@ import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import java.io.IOException; +import java.nio.ByteBuffer; + /** * A values reader for Parquet's run-length encoded data. This is based off of the version in * parquet-mr with these changes: @@ -49,9 +53,7 @@ private enum MODE { } // Encoded data. - private byte[] in; - private int end; - private int offset; + private ByteBufferInputStream in; // bit/byte width of decoded data and utility to batch unpack them. private int bitWidth; @@ -70,45 +72,40 @@ private enum MODE { // If true, the bit width is fixed. This decoder is used in different places and this also // controls if we need to read the bitwidth from the beginning of the data stream. private final boolean fixedWidth; + private final boolean readLength; public VectorizedRleValuesReader() { - fixedWidth = false; + this.fixedWidth = false; + this.readLength = false; } public VectorizedRleValuesReader(int bitWidth) { - fixedWidth = true; + this.fixedWidth = true; + this.readLength = bitWidth != 0; + init(bitWidth); + } + + public VectorizedRleValuesReader(int bitWidth, boolean readLength) { + this.fixedWidth = true; + this.readLength = readLength; init(bitWidth); } @Override - public void initFromPage(int valueCount, byte[] page, int start) { - this.offset = start; - this.in = page; + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; if (fixedWidth) { - if (bitWidth != 0) { + // initialize for repetition and definition levels + if (readLength) { int length = readIntLittleEndian(); - this.end = this.offset + length; + this.in = in.sliceStream(length); } } else { - this.end = page.length; - if (this.end != this.offset) init(page[this.offset++] & 255); - } - if (bitWidth == 0) { - // 0 bit width, treat this as an RLE run of valueCount number of 0's. - this.mode = MODE.RLE; - this.currentCount = valueCount; - this.currentValue = 0; - } else { - this.currentCount = 0; + // initialize for values + if (in.available() > 0) { + init(in.read()); + } } - } - - // Initialize the reader from a buffer. This is used for the V2 page encoding where the - // definition are in its own buffer. - public void initFromBuffer(int valueCount, byte[] data) { - this.offset = 0; - this.in = data; - this.end = data.length; if (bitWidth == 0) { // 0 bit width, treat this as an RLE run of valueCount number of 0's. 
this.mode = MODE.RLE; @@ -129,11 +126,6 @@ private void init(int bitWidth) { this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth); } - @Override - public int getNextOffset() { - return this.end; - } - @Override public boolean readBoolean() { return this.readInteger() != 0; @@ -182,7 +174,7 @@ public void readIntegers( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -217,7 +209,7 @@ public void readBooleans( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -251,7 +243,7 @@ public void readBytes( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -285,7 +277,7 @@ public void readShorts( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -321,7 +313,7 @@ public void readLongs( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -355,7 +347,7 @@ public void readFloats( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -389,7 +381,7 @@ public void readDoubles( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -423,7 +415,7 @@ public void readBinarys( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -462,7 +454,7 @@ public void readIntegers( WritableColumnVector nulls, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -559,12 +551,12 @@ public Binary readBinary(int len) { /** * Reads the next varint encoded int. */ - private int readUnsignedVarInt() { + private int readUnsignedVarInt() throws IOException { int value = 0; int shift = 0; int b; do { - b = in[offset++] & 255; + b = in.read(); value |= (b & 0x7F) << shift; shift += 7; } while ((b & 0x80) != 0); @@ -574,35 +566,32 @@ private int readUnsignedVarInt() { /** * Reads the next 4 byte little endian int. */ - private int readIntLittleEndian() { - int ch4 = in[offset] & 255; - int ch3 = in[offset + 1] & 255; - int ch2 = in[offset + 2] & 255; - int ch1 = in[offset + 3] & 255; - offset += 4; + private int readIntLittleEndian() throws IOException { + int ch4 = in.read(); + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } /** * Reads the next byteWidth little endian int. 
*/ - private int readIntLittleEndianPaddedOnBitWidth() { + private int readIntLittleEndianPaddedOnBitWidth() throws IOException { switch (bytesWidth) { case 0: return 0; case 1: - return in[offset++] & 255; + return in.read(); case 2: { - int ch2 = in[offset] & 255; - int ch1 = in[offset + 1] & 255; - offset += 2; + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 8) + ch2; } case 3: { - int ch3 = in[offset] & 255; - int ch2 = in[offset + 1] & 255; - int ch1 = in[offset + 2] & 255; - offset += 3; + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); } case 4: { @@ -619,32 +608,36 @@ private int ceil8(int value) { /** * Reads the next group. */ - private void readNextGroup() { - int header = readUnsignedVarInt(); - this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; - switch (mode) { - case RLE: - this.currentCount = header >>> 1; - this.currentValue = readIntLittleEndianPaddedOnBitWidth(); - return; - case PACKED: - int numGroups = header >>> 1; - this.currentCount = numGroups * 8; - int bytesToRead = ceil8(this.currentCount * this.bitWidth); - - if (this.currentBuffer.length < this.currentCount) { - this.currentBuffer = new int[this.currentCount]; - } - currentBufferIdx = 0; - int valueIndex = 0; - for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { - this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); - valueIndex += 8; - } - offset += bytesToRead; - return; - default: - throw new ParquetDecodingException("not a valid mode " + this.mode); + private void readNextGroup() { + try { + int header = readUnsignedVarInt(); + this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; + switch (mode) { + case RLE: + this.currentCount = header >>> 1; + this.currentValue = readIntLittleEndianPaddedOnBitWidth(); + return; + case PACKED: + int numGroups = header >>> 1; + this.currentCount = numGroups * 8; + + if (this.currentBuffer.length < this.currentCount) { + this.currentBuffer = new int[this.currentCount]; + } + currentBufferIdx = 0; + int valueIndex = 0; + while (valueIndex < this.currentCount) { + // values are bit packed 8 at a time, so reading bitWidth will always work + ByteBuffer buffer = in.slice(bitWidth); + this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex); + valueIndex += 8; + } + return; + default: + throw new ParquetDecodingException("not a valid mode " + this.mode); + } + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read from input stream", e); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4733f36174f42..6fdadde628551 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -216,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) { @Override public void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(null, data + 2L * rowId, value); } @Override public void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data + 2L * rowId; for (int i = 0; i < count; ++i, offset += 2) { Platform.putShort(null, offset, value); } @@ -229,20 +229,20 @@ public void putShorts(int 
rowId, int count, short value) { @Override public void putShorts(int rowId, int count, short[] src, int srcIndex) { - Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L, + null, data + 2L * rowId, count * 2L); } @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 2, count * 2); + null, data + rowId * 2L, count * 2L); } @Override public short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(null, data + 2L * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -252,7 +252,7 @@ public short getShort(int rowId) { public short[] getShorts(int rowId, int count) { assert(dictionary == null); short[] array = new short[count]; - Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); return array; } @@ -262,12 +262,12 @@ public short[] getShorts(int rowId, int count) { @Override public void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(null, data + 4L * rowId, value); } @Override public void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putInt(null, offset, value); } @@ -275,24 +275,24 @@ public void putInts(int rowId, int count, int value) { @Override public void putInts(int rowId, int count, int[] src, int srcIndex) { - Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + null, data + 4L * rowId, count * 4L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { Platform.putInt(null, offset, java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); @@ -303,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -313,7 +313,7 @@ public int getInt(int rowId) { public int[] getInts(int rowId, int count) { assert(dictionary == null); int[] array = new int[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); return array; } @@ -325,7 +325,7 @@ public int[] getInts(int rowId, int count) 
{ public int getDictId(int rowId) { assert(dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } // @@ -334,12 +334,12 @@ public int getDictId(int rowId) { @Override public void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(null, data + 8L * rowId, value); } @Override public void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putLong(null, offset, value); } @@ -347,24 +347,24 @@ public void putLongs(int rowId, int count, long value) { @Override public void putLongs(int rowId, int count, long[] src, int srcIndex) { - Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + null, data + 8L * rowId, count * 8L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { Platform.putLong(null, offset, java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); @@ -375,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(null, data + 8L * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); } @@ -385,7 +385,7 @@ public long getLong(int rowId) { public long[] getLongs(int rowId, int count) { assert(dictionary == null); long[] array = new long[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); return array; } @@ -395,12 +395,12 @@ public long[] getLongs(int rowId, int count) { @Override public void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(null, data + rowId * 4L, value); } @Override public void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, value); } @@ -408,18 +408,18 @@ public void putFloats(int rowId, int count, float value) { @Override public void putFloats(int rowId, int count, float[] src, int srcIndex) { - Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, 
Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); } @@ -429,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(null, data + rowId * 4L); } else { return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId)); } @@ -439,7 +439,7 @@ public float getFloat(int rowId) { public float[] getFloats(int rowId, int count) { assert(dictionary == null); float[] array = new float[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); return array; } @@ -450,12 +450,12 @@ public float[] getFloats(int rowId, int count) { @Override public void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(null, data + rowId * 8L, value); } @Override public void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, value); } @@ -463,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) { @Override public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); } @@ -484,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(null, data + rowId * 8L); } else { return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); } @@ -494,7 +494,7 @@ public double getDouble(int rowId) { public double[] getDoubles(int rowId, int count) { assert(dictionary == null); double[] array = new double[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L); return array; } @@ -504,26 +504,26 @@ public double[] getDoubles(int rowId, int count) { @Override public void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(null, 
lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, offset); } @Override public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(null, lengthData + 4L * rowId); } @Override public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(null, offsetData + 4L * rowId); } // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, result); return result; } @@ -533,19 +533,19 @@ protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; if (isArray() || type instanceof MapType) { this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L); } else if (childColumns != null) { // Nothing to store. 
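// [Illustrative aside, not part of the patch] The widening of 2 * rowId / 4 * rowId / 8 * rowId to
// 2L / 4L / 8L in OffHeapColumnVector above is an overflow fix: the multiplication is evaluated in
// 32-bit int arithmetic before it is added to the long base address, so a large row index wraps
// around and yields a bogus off-heap offset. A tiny stand-alone demonstration (the numbers are
// arbitrary, chosen only to trigger the wrap-around):

public class OffsetOverflowDemo {
  public static void main(String[] args) {
    long base = 0L;          // stand-in for the off-heap base address
    int rowId = 600_000_000; // large but perfectly valid int row index

    long wrong = base + 8 * rowId;  // 8 * rowId overflows int before being widened to long
    long right = base + 8L * rowId; // 8L forces the multiplication itself into long arithmetic

    System.out.println(wrong); // 505032704   (wrapped around)
    System.out.println(right); // 4800000000  (the intended byte offset)
  }
}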
} else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 23dcc104e67c4..577eab6ed14c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, - Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L); } @Override @@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, - Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L); } @Override @@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, - Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L); } @Override @@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { @@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5275e4a91eac0..b0e119d658cb4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -81,7 +81,9 @@ public void close() { } public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { + if (requiredCapacity < 0) { + throwUnsupportedException(requiredCapacity, null); + } else if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); if (requiredCapacity <= newCapacity) { try { @@ -96,13 +98,16 @@ public void reserve(int requiredCapacity) { } private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized 
reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader, or increase the vectorized reader batch size. For parquet file " + - "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + - SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; + String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + + "vectorized reader. For parquet file format, refer to " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java new file mode 100644 index 0000000000000..f403dc619e86c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for batch processing. + * + * This interface is used to create {@link BatchReadSupport} instances when end users run + * {@code SparkSession.read.format(...).option(...).load()}. + */ +@InterfaceStability.Evolving +public interface BatchReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link BatchReadSupport} instance to load the data from this data source with a user + * specified schema, which is called by Spark at the beginning of each batch query. + * + * Spark will call this method at the beginning of each batch query to create a + * {@link BatchReadSupport} instance. 
+ * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user specified schema. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link BatchReadSupport} instance to scan the data from this data source, which is + * called by Spark at the beginning of each batch query. + * + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + BatchReadSupport createBatchReadSupport(DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java similarity index 58% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java index cab56453816cc..bd10c3353bf12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java @@ -21,32 +21,39 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability and save the data to the data source. + * provide data writing ability for batch processing. + * + * This interface is used to create {@link BatchWriteSupport} instances when end users run + * {@code Dataset.write.format(...).option(...).save()}. */ @InterfaceStability.Evolving -public interface WriteSupport extends DataSourceV2 { +public interface BatchWriteSupportProvider extends DataSourceV2 { /** - * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done according to the save mode. + * Creates an optional {@link BatchWriteSupport} instance to save the data to this data source, + * which is called by Spark at the beginning of each batch query. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was - * submitted. + * Data sources can return None if there is no writing needed to be done according to the save + * mode. * - * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceWriter} can - * use this job id to distinguish itself from other jobs. + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link BatchWriteSupport} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. 
* @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. + * @return a write support to write data to this data source. */ - Optional createWriter( - String jobId, StructType schema, SaveMode mode, DataSourceOptions options); + Optional createBatchWriteSupport( + String queryId, + StructType schema, + SaveMode mode, + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java deleted file mode 100644 index 7df5a451ae5f3..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for continuous stream processing. - */ -@InterfaceStability.Evolving -public interface ContinuousReadSupport extends DataSourceV2 { - /** - * Creates a {@link ContinuousReader} to scan the data from this data source. - * - * @param schema the user provided schema, or empty() if none was provided - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - ContinuousReader createContinuousReader( - Optional schema, - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java new file mode 100644 index 0000000000000..824c290518acf --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for continuous stream processing. + * + * This interface is used to create {@link ContinuousReadSupport} instances when end users run + * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. + */ +@InterfaceStability.Evolving +public interface ContinuousReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data + * source with a user specified schema, which is called by Spark at the beginning of each + * continuous streaming query. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user provided schema. + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default ContinuousReadSupport createContinuousReadSupport( + StructType schema, + String checkpointLocation, + DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data + * source, which is called by Spark at the beginning of each continuous streaming query. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. 
+ */ + ContinuousReadSupport createContinuousReadSupport( + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java similarity index 58% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java index 0ea4dc6b5def3..7011a70e515e2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java @@ -18,23 +18,16 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; /** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability and scan the data from the data source. + * An interface for reporting custom metrics from streaming sources and sinks */ @InterfaceStability.Evolving -public interface ReadSupport extends DataSourceV2 { - +public interface CustomMetrics { /** - * Creates a {@link DataSourceReader} to scan the data from this data source. - * - * If this method fails (by throwing an exception), the action would fail and no Spark job was - * submitted. + * Returns a JSON serialized representation of custom metrics * - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. + * @return JSON serialized representation of custom metrics */ - DataSourceReader createReader(DataSourceOptions options); + String json(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6234071320dc9..6e31e84bf6c72 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -22,9 +22,13 @@ /** * The base interface for data source v2. Implementations must have a public, 0-arg constructor. * - * Note that this is an empty interface. Data source implementations should mix-in at least one of - * the plug-in interfaces like {@link ReadSupport} and {@link WriteSupport}. Otherwise it's just - * a dummy data source which is un-readable/writable. + * Note that this is an empty interface. Data source implementations must mix in interfaces such as + * {@link BatchReadSupportProvider} or {@link BatchWriteSupportProvider}, which can provide + * batch or streaming read/write support instances. Otherwise it's just a dummy data source which + * is un-readable/writable. + * + * If Spark fails to execute any methods in the implementations of this interface (by throwing an + * exception), the read action will fail and no Spark job will be submitted. */ @InterfaceStability.Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java deleted file mode 100644 index 209ffa7a0b9fa..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide streaming micro-batch data reading ability. - */ -@InterfaceStability.Evolving -public interface MicroBatchReadSupport extends DataSourceV2 { - /** - * Creates a {@link MicroBatchReader} to read batches of data from this data source in a - * streaming query. - * - * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and - * then call stop() when the execution is complete. Note that a single query may have multiple - * executions due to restart or failure recovery. - * - * @param schema the user provided schema, or empty() if none was provided - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - MicroBatchReader createMicroBatchReader( - Optional schema, - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java new file mode 100644 index 0000000000000..61c08e7fa89df --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for micro-batch stream processing. + * + * This interface is used to create {@link MicroBatchReadSupport} instances when end users run + * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. + */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data + * source with a user specified schema, which is called by Spark at the beginning of each + * micro-batch streaming query. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user provided schema. + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default MicroBatchReadSupport createMicroBatchReadSupport( + StructType schema, + String checkpointLocation, + DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data + * source, which is called by Spark at the beginning of each micro-batch streaming query. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + MicroBatchReadSupport createMicroBatchReadSupport( + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java deleted file mode 100644 index 3801402268af1..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability and scan the data from the data source. - * - * This is a variant of {@link ReadSupport} that accepts user-specified schema when reading data. - * A data source can implement both {@link ReadSupport} and {@link ReadSupportWithSchema} if it - * supports both schema inference and user-specified schema. - */ -@InterfaceStability.Evolving -public interface ReadSupportWithSchema extends DataSourceV2 { - - /** - * Create a {@link DataSourceReader} to scan the data from this data source. - * - * If this method fails (by throwing an exception), the action would fail and no Spark job was - * submitted. - * - * @param schema the full schema of this data source reader. Full schema usually maps to the - * physical schema of the underlying storage of this data source reader, e.g. - * CSV files, JSON files, etc, while this reader may not read data with full - * schema, as column pruning or other optimizations may happen. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - DataSourceReader createReader(StructType schema, DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 9d66805d79b9e..bbe430e299261 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -27,10 +27,11 @@ @InterfaceStability.Evolving public interface SessionConfigSupport extends DataSourceV2 { - /** - * Key prefix of the session configs to propagate. Spark will extract all session configs that - * starts with `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy` - * into `xxx -> yyy`, and propagate them to all data source operations in this session. - */ - String keyPrefix(); + /** + * Key prefix of the session configs to propagate, which is usually the data source name. Spark + * will extract all session configs that starts with `spark.datasource.$keyPrefix`, turn + * `spark.datasource.$keyPrefix.xxx -> yyy` into `xxx -> yyy`, and propagate them to all + * data source operations in this session. + */ + String keyPrefix(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java deleted file mode 100644 index a77b01497269e..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for structured streaming. - */ -@InterfaceStability.Evolving -public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink { - - /** - * Creates an optional {@link StreamWriter} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done. - * - * @param queryId A unique string for the writing query. It's possible that there are many - * writing queries running at the same time, and the returned - * {@link DataSourceWriter} can use this id to distinguish itself from others. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive epoch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - StreamWriter createStreamWriter( - String queryId, - StructType schema, - OutputMode mode, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java new file mode 100644 index 0000000000000..f9ca85d8089b4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; +import org.apache.spark.sql.streaming.OutputMode; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability for structured streaming. + * + * This interface is used to create {@link StreamingWriteSupport} instances when end users run + * {@code Dataset.writeStream.format(...).option(...).start()}. + */ +@InterfaceStability.Evolving +public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { + + /** + * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is + * called by Spark at the beginning of each streaming query. + * + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link StreamingWriteSupport} can use this id to distinguish itself from others. + * @param schema the schema of the data to be written. + * @param mode the output mode which determines what successive epoch output means to this + * sink, please refer to {@link OutputMode} for more details. + * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. + */ + StreamingWriteSupport createStreamingWriteSupport( + String queryId, + StructType schema, + OutputMode mode, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java new file mode 100644 index 0000000000000..452ee86675b42 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface that defines how to load the data from data source for batch processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.BatchReadSupportProvider}) at the start of a batch + * query, then call {@link #newScanConfigBuilder()} and create an instance of {@link ScanConfig}. + * The {@link ScanConfigBuilder} can apply operator pushdown and keep the pushdown result in + * {@link ScanConfig}. 
The {@link ScanConfig} will be used to create input partitions and reader + * factory to scan data from the data source with a Spark job. + */ +@InterfaceStability.Evolving +public interface BatchReadSupport extends ReadSupport { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query specific information to do operator pushdown, and keep this + * information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link BatchReadSupport} need + * to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(); + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(ScanConfig config); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java deleted file mode 100644 index a470bccc5aad2..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; -import org.apache.spark.sql.types.StructType; - -/** - * A data source reader that is returned by - * {@link ReadSupport#createReader(DataSourceOptions)} or - * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. - * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link DataReaderFactory}s that are returned by - * {@link #createDataReaderFactories()}. - * - * There are mainly 3 kinds of query optimizations: - * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column - * pruning), etc. Names of these interfaces start with `SupportsPushDown`. - * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. - * Names of these interfaces start with `SupportsReporting`. - * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`.
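As a rough illustration of the lifecycle the BatchReadSupport javadoc above describes (not part of this patch), the call sequence the engine roughly follows could be sketched as below; the driveScan helper and its support argument are hypothetical names for this example only.

import org.apache.spark.sql.sources.v2.reader.*;

// Illustrative only: the driver-side steps from newScanConfigBuilder() to createReaderFactory().
class BatchScanLifecycle {
  static void driveScan(BatchReadSupport support) {
    ScanConfigBuilder builder = support.newScanConfigBuilder();            // 1. start a new scan
    // 2. Spark applies pushdown here via mix-ins such as SupportsPushDownFilters, if implemented.
    ScanConfig config = builder.build();                                   // 3. freeze the scan state
    InputPartition[] partitions = support.planInputPartitions(config);     // 4. one split per Spark task
    PartitionReaderFactory factory = support.createReaderFactory(config);  // 5. serialized to executors
  }
}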
Note that a reader should only - * implement at most one of the special scans, if more than one special scans are implemented, - * only one of them would be respected, according to the priority list from high to low: - * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. - * - * If an exception was throw when applying any of these query optimizations, the action would fail - * and no Spark job was submitted. - * - * Spark first applies all operator push-down optimizations that this data source supports. Then - * Spark collects information this data source reported for further optimizations. Finally Spark - * issues the scan request and does the actual data reading. - */ -@InterfaceStability.Evolving -public interface DataSourceReader { - - /** - * Returns the actual schema of this data source reader, which may be different from the physical - * schema of the underlying storage, as column pruning or other optimizations may happen. - * - * If this method fails (by throwing an exception), the action would fail and no Spark job was - * submitted. - */ - StructType readSchema(); - - /** - * Returns a list of reader factories. Each factory is responsible for creating a data reader to - * output data for one RDD partition. That means the number of factories returned here is same as - * the number of RDD partitions this scan outputs. - * - * Note that, this may not be a full scan if the data source reader mixes in other optimization - * interfaces like column pruning, filter push-down, etc. These optimizations are applied before - * Spark issues the scan request. - * - * If this method fails (by throwing an exception), the action would fail and no Spark job was - * submitted. - */ - List> createDataReaderFactories(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java similarity index 52% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 32e98e8f5d8bd..95c30de907e44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -22,40 +22,33 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is - * responsible for creating the actual data reader. The relationship between - * {@link DataReaderFactory} and {@link DataReader} - * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. + * A serializable representation of an input partition returned by + * {@link ReadSupport#planInputPartitions(ScanConfig)}. * - * Note that, the reader factory will be serialized and sent to executors, then the data reader - * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be - * serializable and {@link DataReader} doesn't need to be. + * Note that {@link InputPartition} will be serialized and sent to executors, then + * {@link PartitionReader} will be created by + * {@link PartitionReaderFactory#createReader(InputPartition)} or + * {@link PartitionReaderFactory#createColumnarReader(InputPartition)} on executors to do + * the actual reading. 
So {@link InputPartition} must be serializable while {@link PartitionReader} + * doesn't need to be. */ @InterfaceStability.Evolving -public interface DataReaderFactory extends Serializable { +public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this reader factory can run faster, - * but Spark does not guarantee to run the data reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run + * faster, but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in - * the returned locations. By default this method returns empty string array, which means this - * task has no location preference. + * the returned locations. The default return value is empty string array, which means this + * input partition's reader has no location preference. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ default String[] preferredLocations() { return new String[0]; } - - /** - * Returns a data reader to do the actual reading work. - * - * If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. - */ - DataReader createDataReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java similarity index 68% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java index bb9790a1c819e..04ff8d0a19fc3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java @@ -23,31 +23,27 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for + * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or + * {@link PartitionReaderFactory#createColumnarReader(InputPartition)}. It's responsible for * outputting data for a RDD partition. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source - * readers that mix in {@link SupportsScanUnsafeRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow} + * for normal data sources, or {@link org.apache.spark.sql.vectorized.ColumnarBatch} for columnar + * data sources(whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)} + * returns true). */ @InterfaceStability.Evolving -public interface DataReader extends Closeable { +public interface PartitionReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. 
* - If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. - * * @throws IOException if failure happens during disk/network IO like reading files. */ boolean next() throws IOException; /** * Return the current record. This method should return same value until `next` is called. - * - * If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. */ T get(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java new file mode 100644 index 0000000000000..f35de9310eee3 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A factory used to create {@link PartitionReader} instances. + * + * If Spark fails to execute any methods in the implementations of this interface or in the returned + * {@link PartitionReader} (by throwing an exception), the corresponding Spark task will fail and + * get retried until hitting the maximum retry times. + */ +@InterfaceStability.Evolving +public interface PartitionReaderFactory extends Serializable { + + /** + * Returns a row-based partition reader to read data from the given {@link InputPartition}. + * + * Implementations probably need to cast the input partition to the concrete + * {@link InputPartition} class defined for the data source. + */ + PartitionReader<InternalRow> createReader(InputPartition partition); + + /** + * Returns a columnar partition reader to read data from the given {@link InputPartition}. + * + * Implementations probably need to cast the input partition to the concrete + * {@link InputPartition} class defined for the data source. + */ + default PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) { + throw new UnsupportedOperationException("Cannot create columnar reader."); + } + + /** + * Returns true if the given {@link InputPartition} should be read by Spark in a columnar way. + * This means implementations must also implement {@link #createColumnarReader(InputPartition)} + * for the input partitions for which this method returns true. + * + * As of Spark 2.4, Spark can only read all input partitions in a columnar way, or none of them. + * Data sources can't mix columnar and row-based partitions.
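To make the InputPartition / PartitionReaderFactory / PartitionReader split above concrete, here is a minimal, illustrative row-based source over an integer range. The RangeInputPartition and RangeReaderFactory names are invented for this sketch and are not part of the patch.

import java.io.IOException;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.PartitionReader;
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;

// One serializable split: the [start, end) range this task should produce.
class RangeInputPartition implements InputPartition {
  final int start;
  final int end;
  RangeInputPartition(int start, int end) { this.start = start; this.end = end; }
}

// Created on the driver, serialized to executors, then asked for one reader per partition.
class RangeReaderFactory implements PartitionReaderFactory {
  @Override
  public PartitionReader<InternalRow> createReader(InputPartition partition) {
    RangeInputPartition p = (RangeInputPartition) partition;
    return new PartitionReader<InternalRow>() {
      private int current = p.start - 1;

      @Override
      public boolean next() { current += 1; return current < p.end; }

      @Override
      public InternalRow get() {
        // Single int column; the boxed value is unwrapped by GenericInternalRow.getInt.
        return new GenericInternalRow(new Object[] { current });
      }

      @Override
      public void close() throws IOException { /* nothing to release */ }
    };
  }
  // supportColumnarReads keeps its default (false), so only the row-based path is used.
}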
This may be relaxed in future + * versions. + */ + default boolean supportColumnarReads(InputPartition partition) { + return false; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java new file mode 100644 index 0000000000000..a58ddb288f1ed --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.StructType; + +/** + * The base interface for all the batch and streaming read supports. Data sources should implement + * concrete read support interfaces like {@link BatchReadSupport}. + * + * If Spark fails to execute any methods in the implementations of this interface (by throwing an + * exception), the read action will fail and no Spark job will be submitted. + */ +@InterfaceStability.Evolving +public interface ReadSupport { + + /** + * Returns the full schema of this data source, which is usually the physical schema of the + * underlying storage. This full schema should not be affected by column pruning or other + * optimizations. + */ + StructType fullSchema(); + + /** + * Returns a list of {@link InputPartition input partitions}. Each {@link InputPartition} + * represents a data split that can be processed by one Spark task. The number of input + * partitions returned here is the same as the number of RDD partitions this scan outputs. + * + * Note that, this may not be a full scan if the data source supports optimization like filter + * push-down. Implementations should check the input {@link ScanConfig} and adjust the resulting + * {@link InputPartition input partitions}. + */ + InputPartition[] planInputPartitions(ScanConfig config); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java new file mode 100644 index 0000000000000..7462ce2820585 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.StructType; + +/** + * An interface that carries query specific information for the data scanning job, like operator + * pushdown information and streaming query offsets. This is defined as an empty interface, and data + * sources should define their own {@link ScanConfig} classes. + * + * For APIs that take a {@link ScanConfig} as input, like + * {@link ReadSupport#planInputPartitions(ScanConfig)}, + * {@link BatchReadSupport#createReaderFactory(ScanConfig)} and + * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to + * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. + */ +@InterfaceStability.Evolving +public interface ScanConfig { + + /** + * Returns the actual schema of this data source reader, which may be different from the physical + * schema of the underlying storage, as column pruning or other optimizations may happen. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. + */ + StructType readSchema(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java similarity index 62% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java index a61697649c43e..4c0eedfddfe22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java @@ -18,18 +18,13 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; /** - * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can - * implement this interface to provide creating {@link DataReader} with particular offset. + * An interface for building the {@link ScanConfig}. Implementations can mixin those + * SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in + * the returned {@link ScanConfig}. */ @InterfaceStability.Evolving -public interface ContinuousDataReaderFactory extends DataReaderFactory { - /** - * Create a DataReader with particular offset as its startOffset. - * - * @param offset offset want to set as the DataReader's startOffset. 
- */ - DataReader createDataReaderWithOffset(PartitionOffset offset); +public interface ScanConfigBuilder { + ScanConfig build(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index e8cd7adbca071..44799c7d49137 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -23,7 +23,7 @@ /** * An interface to represent statistics for a data source, which is returned by - * {@link SupportsReportStatistics#getStatistics()}. + * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. */ @InterfaceStability.Evolving public interface Statistics { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java deleted file mode 100644 index 290d614805ac7..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.expressions.Expression; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to push down arbitrary expressions as predicates to the data source. - * This is an experimental and unstable interface as {@link Expression} is not public and may get - * changed in the future Spark versions. - * - * Note that, if data source readers implement both this interface and - * {@link SupportsPushDownFilters}, Spark will ignore {@link SupportsPushDownFilters} and only - * process this interface. - */ -@InterfaceStability.Unstable -public interface SupportsPushDownCatalystFilters extends DataSourceReader { - - /** - * Pushes down filters, and returns filters that need to be evaluated after scanning. - */ - Expression[] pushCatalystFilters(Expression[] filters); - - /** - * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. - * It's possible that there is no filters in the query and - * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for - * this case. 
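Pulling the pieces together, a minimal BatchReadSupport for the same hypothetical range source might look as follows, assuming the RangeInputPartition and RangeReaderFactory classes from the earlier sketch live in the same package. ScanConfig is reduced to the one method shown in this patch (readSchema), and no pushdown is attempted.

import org.apache.spark.sql.sources.v2.reader.BatchReadSupport;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Immutable scan state: this trivial source only needs to remember the schema it reports.
class RangeScanConfig implements ScanConfig {
  private final StructType schema;
  RangeScanConfig(StructType schema) { this.schema = schema; }
  @Override
  public StructType readSchema() { return schema; }
}

class RangeBatchReadSupport implements BatchReadSupport {
  private final StructType schema = new StructType().add("value", DataTypes.IntegerType);

  @Override
  public StructType fullSchema() { return schema; }

  @Override
  public ScanConfigBuilder newScanConfigBuilder() {
    // No pushdown support here: the builder simply freezes the full schema.
    return () -> new RangeScanConfig(schema);
  }

  @Override
  public InputPartition[] planInputPartitions(ScanConfig config) {
    // Two splits of fifty values each, read by the RangeReaderFactory from the earlier sketch.
    return new InputPartition[] { new RangeInputPartition(0, 50), new RangeInputPartition(50, 100) };
  }

  @Override
  public PartitionReaderFactory createReaderFactory(ScanConfig config) {
    return new RangeReaderFactory();
  }
}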
- */ - Expression[] pushedCatalystFilters(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 1cff024232a44..5e7985f645a06 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,15 +21,11 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to push down filters to the data source and reduce the size of the data to be read. - * - * Note that, if data source readers implement both this interface and - * {@link SupportsPushDownCatalystFilters}, Spark will ignore this interface and only process - * {@link SupportsPushDownCatalystFilters}. + * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to + * push down filters to the data source and reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownFilters extends DataSourceReader { +public interface SupportsPushDownFilters extends ScanConfigBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. @@ -37,7 +33,15 @@ public interface SupportsPushDownFilters extends DataSourceReader { Filter[] pushFilters(Filter[] filters); /** - * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} * is never called, empty array should be returned for this case. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index 427b4d00a1128..edb164937d6ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownRequiredColumns extends DataSourceReader { +public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { /** * Applies column pruning w.r.t. the given requiredSchema. @@ -35,8 +35,8 @@ public interface SupportsPushDownRequiredColumns extends DataSourceReader { * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. 
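Below is a sketch of a ScanConfigBuilder that mixes in the two pushdown interfaces discussed here. The pushedFilters() accessor is assumed from the existing SupportsPushDownFilters contract (its signature sits outside the changed hunk), and this builder only records what Spark hands it rather than pushing anything to a real storage layer.

import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns;
import org.apache.spark.sql.types.StructType;

// Records pruned columns and pushed filters, then freezes them into an immutable ScanConfig.
class PushdownScanConfigBuilder
    implements ScanConfigBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns {

  private StructType requiredSchema;
  private Filter[] pushed = new Filter[0];

  PushdownScanConfigBuilder(StructType fullSchema) { this.requiredSchema = fullSchema; }

  @Override
  public void pruneColumns(StructType requiredSchema) {
    // Spark hands us the columns the query actually needs; read only those.
    this.requiredSchema = requiredSchema;
  }

  @Override
  public Filter[] pushFilters(Filter[] filters) {
    // Pretend every filter is handled by the source but still needs re-evaluation after the
    // scan (case 2 in the javadoc above), so return them all as post-scan filters too.
    this.pushed = filters;
    return filters;
  }

  @Override
  public Filter[] pushedFilters() { return pushed; }

  @Override
  public ScanConfig build() {
    final StructType schema = requiredSchema;
    return new ScanConfig() {
      @Override
      public StructType readSchema() { return schema; }
    };
  }
}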
* - * Note that, data source readers should update {@link DataSourceReader#readSchema()} after - * applying column pruning. + * Note that, {@link ScanConfig#readSchema()} implementation should take care of the column + * pruning applied here. */ void pruneColumns(StructType requiredSchema); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 607628746e873..db62cd4515362 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -21,17 +21,17 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** - * A mix in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to report data partitioning and try to avoid shuffle at Spark side. + * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid - * adding a shuffle even if the reader does not implement this interface. + * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, + * Spark may avoid adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving -public interface SupportsReportPartitioning extends DataSourceReader { +public interface SupportsReportPartitioning extends ReadSupport { /** * Returns the output data partitioning that this reader guarantees. */ - Partitioning outputPartitioning(); + Partitioning outputPartitioning(ScanConfig config); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 11bb13fd3b211..1831488ba096f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,14 +20,18 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to report statistics to Spark. + * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * report statistics to Spark. + * + * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the + * data source. Implementations that return more accurate statistics based on pushed operators will + * not improve query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics extends DataSourceReader { +public interface SupportsReportStatistics extends ReadSupport { /** - * Returns the basic statistics of this data source. + * Returns the estimated statistics of this data source scan. 
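For estimateStatistics, a data source might report fixed estimates as in the sketch below. The Statistics accessors (sizeInBytes() and numRows() returning OptionalLong) are not shown in this diff and are assumed from the pre-existing interface; the class is left abstract so the unrelated ReadSupport methods can be omitted, and the numbers are made up.

import java.util.OptionalLong;

import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.Statistics;
import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics;

// Sketch: report a fixed size / row-count estimate. As the javadoc above notes, in Spark 2.4 this
// is consulted before any pushdown, so the ScanConfig argument can be ignored here.
abstract class StatsReportingReadSupport implements SupportsReportStatistics {
  @Override
  public Statistics estimateStatistics(ScanConfig config) {
    return new Statistics() {
      @Override public OptionalLong sizeInBytes() { return OptionalLong.of(10L * 1024 * 1024); }
      @Override public OptionalLong numRows() { return OptionalLong.of(100_000L); }
    };
  }
}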
*/ - Statistics getStatistics(); + Statistics estimateStatistics(ScanConfig config); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java deleted file mode 100644 index 2e5cfa78511f0..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.vectorized.ColumnarBatch; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link ColumnarBatch} and make the scan faster. - */ -@InterfaceStability.Evolving -public interface SupportsScanColumnarBatch extends DataSourceReader { - @Override - default List> createDataReaderFactories() { - throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); - } - - /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data - * in batches. - */ - List> createBatchDataReaderFactories(); - - /** - * Returns true if the concrete data source reader can read data in batch according to the scan - * properties like required columns, pushes filters, etc. It's possible that the implementation - * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. - */ - default boolean enableBatchRead() { - return true; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java deleted file mode 100644 index 9cd749e8e4ce9..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. - * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get - * changed in the future Spark versions. - */ -@InterfaceStability.Unstable -public interface SupportsScanUnsafeRow extends DataSourceReader { - - @Override - default List> createDataReaderFactories() { - throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); - } - - /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, - * but returns data in unsafe row format. - */ - List> createUnsafeRowReaderFactories(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 2d0ee50212b56..6764d4b7665c7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link DataReader}. + * {@link PartitionReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index f6b111fdf220d..364a3f553923c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -18,13 +18,14 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * be distributed among the data partitions (one {@link PartitionReader} outputs data for one + * partition). 
* Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link DataReader}). + * partition(the output records of a single {@link PartitionReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 309d9e5de0a0f..fb0b6f1df43bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -18,20 +18,21 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by - * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a - * snapshot. Once created, it should be deterministic and always report the same number of + * {@link SupportsReportPartitioning#outputPartitioning(ScanConfig)}. Note that this should work + * like a snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. */ @InterfaceStability.Evolving public interface Partitioning { /** - * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. + * Returns the number of partitions(i.e., {@link InputPartition}s) the data source outputs. */ int numPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java similarity index 61% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java index 47d26440841fd..9101c8a44d34e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java @@ -18,19 +18,20 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** - * A variation on {@link DataReader} for use with streaming in continuous processing mode. + * A variation on {@link PartitionReader} for use with continuous streaming processing. */ @InterfaceStability.Evolving -public interface ContinuousDataReader extends DataReader { - /** - * Get the offset of the current record, or the start offset if no records have been read. - * - * The execution engine will call this method along with get() to keep track of the current - * offset. 
When an epoch ends, the offset of the previous record in each partition will be saved - * as a restart checkpoint. - */ - PartitionOffset getOffset(); +public interface ContinuousPartitionReader extends PartitionReader { + + /** + * Get the offset of the current record, or the start offset if no records have been read. + * + * The execution engine will call this method along with get() to keep track of the current + * offset. When an epoch ends, the offset of the previous record in each partition will be saved + * as a restart checkpoint. + */ + PartitionOffset getOffset(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java similarity index 52% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java index d2cf7e01c08c8..2d9f1ca1686a1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java @@ -15,27 +15,26 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.writer; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; /** - * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this - * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. - * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get - * changed in the future Spark versions. + * A variation on {@link PartitionReaderFactory} that returns {@link ContinuousPartitionReader} + * instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for + * continuous streaming processing. 
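As an illustration of how getOffset() pairs with next()/get() in the continuous reader contract above, here is a made-up source that emits an ever-growing counter. PartitionOffset is assumed to be the serializable marker interface from this package, and all class names are invented for the sketch.

import java.io.IOException;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReader;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory;
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;

// Offset of a single partition: how many records this reader has produced so far.
class CounterPartitionOffset implements PartitionOffset {
  final long count;
  CounterPartitionOffset(long count) { this.count = count; }
}

// Produces an unbounded counter; getOffset() lets the engine checkpoint per-partition progress.
class CounterReaderFactory implements ContinuousPartitionReaderFactory {
  @Override
  public ContinuousPartitionReader<InternalRow> createReader(InputPartition partition) {
    return new ContinuousPartitionReader<InternalRow>() {
      private long current = -1;

      @Override
      public boolean next() { current += 1; return true; }  // a continuous source never exhausts

      @Override
      public InternalRow get() { return new GenericInternalRow(new Object[] { current }); }

      @Override
      public PartitionOffset getOffset() { return new CounterPartitionOffset(current + 1); }

      @Override
      public void close() throws IOException { }
    };
  }
}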
*/ - -@InterfaceStability.Unstable -public interface SupportsWriteInternalRow extends DataSourceWriter { +@InterfaceStability.Evolving +public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { + @Override + ContinuousPartitionReader createReader(InputPartition partition); @Override - default DataWriterFactory createWriterFactory() { - throw new IllegalStateException( - "createWriterFactory should not be called with SupportsWriteInternalRow."); + default ContinuousPartitionReader createColumnarReader(InputPartition partition) { + throw new UnsupportedOperationException("Cannot create columnar reader."); } - - DataWriterFactory createInternalRowWriterFactory(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java new file mode 100644 index 0000000000000..9a3ad2eb8a801 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; +import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; + +/** + * An interface that defines how to load the data from data source for continuous streaming + * processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.ContinuousReadSupportProvider}) at the start of a + * streaming query, then call {@link #newScanConfigBuilder(Offset)} and create an instance of + * {@link ScanConfig} for the duration of the streaming query or until + * {@link #needsReconfiguration(ScanConfig)} is true. The {@link ScanConfig} will be used to create + * input partitions and reader factory to scan data with a Spark job for its duration. At the end + * {@link #stop()} will be called when the streaming execution is completed. Note that a single + * query may have multiple executions due to restart or failure recovery. + */ +@InterfaceStability.Evolving +public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. 
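To make the reader-side contract concrete, here is a minimal Scala sketch of the `ContinuousPartitionReader` / `ContinuousPartitionReaderFactory` pair described above. All names are hypothetical, only the methods shown in this patch are assumed (`next`/`get`/`close` from `PartitionReader`, plus `getOffset`), and the source simply emits an unbounded stream of increasing longs:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, PartitionOffset}

// Hypothetical partition and per-partition offset types for this sketch.
case class CounterPartition(startValue: Long) extends InputPartition
case class CounterPartitionOffset(lastValue: Long) extends PartitionOffset

class CounterContinuousReader(start: Long) extends ContinuousPartitionReader[InternalRow] {
  private var current = start - 1

  // Continuous mode is unbounded: there is always a "next" record to wait for.
  override def next(): Boolean = { current += 1; true }

  override def get(): InternalRow = InternalRow(current)

  // Reported alongside get(); the engine checkpoints it at epoch boundaries.
  override def getOffset(): PartitionOffset = CounterPartitionOffset(current)

  override def close(): Unit = {}
}

object CounterReaderFactory extends ContinuousPartitionReaderFactory {
  override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] =
    partition match {
      case CounterPartition(start) => new CounterContinuousReader(start)
    }
}
```

Only the factory is serialized and shipped to executors; the reader itself is created there, which is why it can safely hold per-partition mutable state.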
+ * + * The builder can take some query specific information to do operators pushdown, store streaming + * offsets, etc., and keep these information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link ContinuousReadSupport} + * needs to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(Offset start); + + /** + * Returns a factory, which produces one {@link ContinuousPartitionReader} for one + * {@link InputPartition}. + */ + ContinuousPartitionReaderFactory createContinuousReaderFactory(ScanConfig config); + + /** + * Merge partitioned offsets coming from {@link ContinuousPartitionReader} instances + * for each partition to a single global offset. + */ + Offset mergeOffsets(PartitionOffset[] offsets); + + /** + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if for example the underlying + * source system has had partitions added or removed. + * + * If true, the query will be shut down and restarted with a new {@link ContinuousReadSupport} + * instance. + */ + default boolean needsReconfiguration(ScanConfig config) { + return false; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java deleted file mode 100644 index 7fe7f00ac2fa8..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; - -import java.util.Optional; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to allow reading in a continuous processing mode stream. - * - * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. - * - * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with - * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. - */ -@InterfaceStability.Evolving -public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { - /** - * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each - * partition to a single global offset. 
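As a rough illustration of `mergeOffsets`, a source that tracks one numeric position per partition could combine the per-partition `PartitionOffset`s reported by its `ContinuousPartitionReader`s into a single global `Offset` as below. The types and the JSON layout are hypothetical; the only assumption taken from this patch is that `Offset#json()` must round-trip through `deserializeOffset`:

```scala
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, PartitionOffset}

// Hypothetical per-partition and global offset types.
case class PartitionPosition(partitionId: Int, position: Long) extends PartitionOffset

case class GlobalPosition(positions: Map[Int, Long]) extends Offset {
  // Any stable, re-parseable encoding works; a small JSON object keyed by partition id here.
  override def json(): String =
    positions.toSeq.sortBy(_._1)
      .map { case (p, pos) => s""""$p":$pos""" }
      .mkString("{", ",", "}")
}

object OffsetMerging {
  // What a ContinuousReadSupport implementation could delegate its mergeOffsets() to.
  def mergeOffsets(offsets: Array[PartitionOffset]): Offset =
    GlobalPosition(offsets.map { case PartitionPosition(p, pos) => p -> pos }.toMap)
}
```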
- */ - Offset mergeOffsets(PartitionOffset[] offsets); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Set the desired start offset for reader factories created from this reader. The scan will - * start from the first record after the provided offset, or from an implementation-defined - * inferred starting point if no offset is provided. - */ - void setStartOffset(Optional start); - - /** - * Return the specified or inferred start offset for this reader. - * - * @throws IllegalStateException if setStartOffset has not been called - */ - Offset getStartOffset(); - - /** - * The execution engine will call this method in every epoch to determine if new reader - * factories need to be generated, which may be required if for example the underlying - * source system has had partitions added or removed. - * - * If true, the query will be shut down and restarted with a new reader. - */ - default boolean needsReconfiguration() { - return false; - } - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java new file mode 100644 index 0000000000000..edb0db11bff2c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.*; + +/** + * An interface that defines how to scan the data from data source for micro-batch streaming + * processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.MicroBatchReadSupportProvider}) at the start of a + * streaming query, then call {@link #newScanConfigBuilder(Offset, Offset)} and create an instance + * of {@link ScanConfig} for each micro-batch. The {@link ScanConfig} will be used to create input + * partitions and reader factory to scan a micro-batch with a Spark job. At the end {@link #stop()} + * will be called when the streaming execution is completed. 
Note that a single query may have + * multiple executions due to restart or failure recovery. + */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query specific information to do operators pushdown, store streaming + * offsets, etc., and keep these information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link MicroBatchReadSupport} + * needs to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end); + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(ScanConfig config); + + /** + * Returns the most recent offset available. + */ + Offset latestOffset(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java deleted file mode 100644 index 67ebde30d61a9..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; - -import java.util.Optional; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to indicate they allow micro-batch streaming reads. - * - * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with - * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. - */ -@InterfaceStability.Evolving -public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { - /** - * Set the desired offset range for reader factories created from this reader. Reader factories - * will generate only data within (`start`, `end`]; that is, from the first record after `start` - * to the record with offset `end`. - * - * @param start The initial offset to scan from. If not specified, scan from an - * implementation-specified start point, such as the earliest available record. - * @param end The last offset to include in the scan. 
If not specified, scan up to an - * implementation-defined endpoint, such as the last available offset - * or the start offset plus a target batch size. - */ - void setOffsetRange(Optional start, Optional end); - - /** - * Returns the specified (if explicitly set through setOffsetRange) or inferred start offset - * for this reader. - * - * @throws IllegalStateException if setOffsetRange has not been called - */ - Offset getStartOffset(); - - /** - * Return the specified (if explicitly set through setOffsetRange) or inferred end offset - * for this reader. - * - * @throws IllegalStateException if setOffsetRange has not been called - */ - Offset getEndOffset(); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index e41c0351edc82..6cf27734867cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -20,8 +20,8 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An abstract representation of progress through a {@link MicroBatchReader} or - * {@link ContinuousReader}. + * An abstract representation of progress through a {@link MicroBatchReadSupport} or + * {@link ContinuousReadSupport}. * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java new file mode 100644 index 0000000000000..84872d1ebc26e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.sql.sources.v2.reader.ReadSupport; + +/** + * A base interface for streaming read support. 
This is package private and is invisible to data + * sources. Data sources should implement concrete streaming read support interfaces: + * {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. + */ +interface StreamingReadSupport extends ReadSupport { + + /** + * Returns the initial offset for a streaming query to start reading from. Note that the + * streaming data source should not assume that it will start reading from its initial offset: + * if Spark is restarting an existing query, it will restart from the check-pointed offset rather + * than the initial one. + */ + Offset initialOffset(); + + /** + * Deserialize a JSON string into an Offset of the implementation-defined offset type. + * + * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader + */ + Offset deserializeOffset(String json); + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java new file mode 100644 index 0000000000000..8693154cb7045 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.CustomMetrics; + +/** + * A mix in interface for {@link StreamingReadSupport}. Data sources can implement this interface + * to report custom metrics that gets reported under the + * {@link org.apache.spark.sql.streaming.SourceProgress} + */ +@InterfaceStability.Evolving +public interface SupportsCustomReaderMetrics extends StreamingReadSupport { + + /** + * Returns custom metrics specific to this data source. + */ + CustomMetrics getCustomMetrics(); + + /** + * Invoked if the custom metrics returned by {@link #getCustomMetrics()} is invalid + * (e.g. Invalid data that cannot be parsed). Throwing an error here would ensure that + * your custom metrics work right and correct values are reported always. The default action + * on invalid metrics is to ignore it. 
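Because every streaming read support has to round-trip its offsets through JSON for checkpointing, a concrete source usually pairs a small `Offset` subclass with matching `initialOffset`/`deserializeOffset`/`commit` overrides. A sketch under that assumption (hypothetical names, shown on an abstract `MicroBatchReadSupport` so the scan-planning members can stay unimplemented):

```scala
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset}

// Hypothetical single-number offset; json() must survive a round trip through deserializeOffset.
case class LongStreamOffset(value: Long) extends Offset {
  override def json(): String = value.toString
}

// Abstract on purpose: a real source would also implement the ScanConfig and
// partition-planning methods of the read-support interfaces.
abstract class OffsetCheckpointing extends MicroBatchReadSupport {
  @volatile private var committedUpTo: Long = -1L

  override def initialOffset(): Offset = LongStreamOffset(0L)

  override def deserializeOffset(json: String): Offset = LongStreamOffset(json.toLong)

  override def commit(end: Offset): Unit = end match {
    // Data at or below this offset will never be requested again and can be dropped.
    case LongStreamOffset(v) => committedUpTo = v
  }
}
```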
+ * + * @param ex the exception + */ + default void onInvalidMetrics(Exception ex) { + // default is to ignore invalid custom metrics + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java similarity index 75% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java index 0a0fd8db58035..0ec9e05d6a02b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java @@ -18,28 +18,13 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.StreamWriteSupport; -import org.apache.spark.sql.sources.v2.WriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; /** - * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ - * {@link StreamWriteSupport#createStreamWriter( - * String, StructType, OutputMode, DataSourceOptions)}. - * It can mix in various writing optimization interfaces to speed up the data saving. The actual - * writing logic is delegated to {@link DataWriter}. - * - * If an exception was throw when applying any of these writing optimizations, the action would fail - * and no Spark job was submitted. + * An interface that defines how to write the data to data source for batch processing. * * The writing procedure is: - * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the - * partitions of the input data(RDD). + * 1. Create a writer factory by {@link #createBatchWriterFactory()}, serialize and send it to all + * the partitions of the input data(RDD). * 2. For each partition, create the data writer, and write the data of the partition with this * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If * exception happens during the writing, call {@link DataWriter#abort()}. @@ -53,19 +38,19 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface DataSourceWriter { +public interface BatchWriteSupport { /** * Creates a writer factory which will be serialized and sent to executors. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ - DataWriterFactory createWriterFactory(); + DataWriterFactory createBatchWriterFactory(); /** - * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for - * each task commits. + * Returns whether Spark should use the commit coordinator to ensure that at most one task for + * each partition commits. * * @return true if commit coordinator should be used, false otherwise. */ @@ -90,9 +75,9 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. 
* * Note that speculative execution may cause multiple tasks to run for a partition. By default, - * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * Spark uses the commit coordinator to allow at most one task to commit. Implementations can * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple - * attempts may have committed successfully and one successful commit message per task will be + * tasks may have committed successfully and one successful commit message per task will be * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 39bf458298862..5fb067966ee67 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is + * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -36,26 +36,24 @@ * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data + * {@link BatchWriteSupport#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a - * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createWriter(int, long)} will receive a + * different `taskId`. Spark will call {@link BatchWriteSupport#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the * previous one fails, speculative tasks are running simultaneously. It's possible that one input - * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * RDD partition has multiple data writers with different `taskId` running at the same time, * and data sources should guarantee that these data writers don't conflict and can work together. * Implementations can coordinate with driver during {@link #commit()} to make sure only one of * these data writers can commit successfully. Or implementations can allow all of them to commit * successfully, and have a way to revert committed data writers without the commit message, because * Spark only accepts the commit message that arrives first and ignore others. 
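An executor-side sketch of the commit protocol described above: the `DataWriter` buffers rows (copying them, since Spark may reuse the incoming row instance) and hands them back to the driver through its `WriterCommitMessage`; `BatchWriteSupport.commit` would then make the buffered data visible atomically. Names are hypothetical, and the factory is assumed to have the non-generic `createWriter(partitionId, taskId)` shape shown in this patch:

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical commit message: carries the rows one task wrote back to the driver.
case class BufferedRowsCommitMessage(partitionId: Int, rows: Seq[InternalRow])
  extends WriterCommitMessage

class BufferingDataWriter(partitionId: Int, taskId: Long) extends DataWriter[InternalRow] {
  private val buffer = ArrayBuffer.empty[InternalRow]

  // Spark may reuse the same row instance between calls, so take a defensive copy.
  override def write(record: InternalRow): Unit = buffer += record.copy()

  override def commit(): WriterCommitMessage =
    BufferedRowsCommitMessage(partitionId, buffer.toList)

  override def abort(): Unit = buffer.clear()
}

class BufferingDataWriterFactory extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] =
    new BufferingDataWriter(partitionId, taskId)
}
```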
* - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers - * that mix in {@link SupportsWriteInternalRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}. */ @InterfaceStability.Evolving public interface DataWriter { @@ -73,11 +71,11 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link DataSourceWriter#commit(WriterCommitMessage[])}. + * {@link BatchWriteSupport#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to + * {@link BatchWriteSupport#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link BatchWriteSupport} at driver side to * do the final commit via {@link WriterCommitMessage}. * * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -95,7 +93,7 @@ public interface DataWriter { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, + * that need to be cleaned by {@link BatchWriteSupport#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index c2c2ab73257e8..19a36dd232456 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -19,38 +19,37 @@ import java.io.Serializable; +import org.apache.spark.TaskContext; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; /** - * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link BatchWriteSupport#createBatchWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer - * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataWriterFactory extends Serializable { +public interface DataWriterFactory extends Serializable { /** - * Returns a data writer to do the actual writing work. + * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data + * object instance when sending data to the data writer, for better performance. Data writers + * are responsible for defensive copies if necessary, e.g. 
copy the data before buffer it in a + * list. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was - * submitted. + * If this method fails (by throwing an exception), the corresponding Spark write task would fail + * and get retried until hitting the maximum retry times. * * @param partitionId A unique id of the RDD partition that the returned writer will process. * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task - * failed, Spark launches a new task wth the same task id but different - * attempt number. Or a task is too slow, Spark launches new tasks wth the - * same task id but different attempt number, which means there are multiple - * tasks with the same task id running at the same time. Implementations can - * use this attempt number to distinguish writers of different task attempts. - * @param epochId A monotonically increasing id for streaming queries that are split in to - * discrete periods of execution. For non-streaming queries, - * this ID will always be 0. + * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run + * multiple tasks for the same partition (due to speculation or task failures, + * for example). */ - DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); + DataWriter createWriter(int partitionId, long taskId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 9e38836c0edf9..123335c414e9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -19,15 +19,16 @@ import java.io.Serializable; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; import org.apache.spark.annotation.InterfaceStability; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. + * as the input parameter of {@link BatchWriteSupport#commit(WriterCommitMessage[])} or + * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. * - * This is an empty interface, data sources should define their own message class and use it in - * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} - * implementations. + * This is an empty interface, data sources should define their own message class and use it when + * generating messages at executor side and handling the messages at driver side. */ @InterfaceStability.Evolving public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java new file mode 100644 index 0000000000000..a4da24fc5ae68 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer.streaming; + +import java.io.Serializable; + +import org.apache.spark.TaskContext; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.writer.DataWriter; + +/** + * A factory of {@link DataWriter} returned by + * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating + * and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing. So this interface must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface StreamingDataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data + * object instance when sending data to the data writer, for better performance. Data writers + * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a + * list. + * + * If this method fails (by throwing an exception), the corresponding Spark write task would fail + * and get retried until hitting the maximum retry times. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run + * multiple tasks for the same partition (due to speculation or task failures, + * for example). + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. 
+ */ + DataWriter createWriter(int partitionId, long taskId, long epochId); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java similarity index 78% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java index a316b2a4c1d82..3fdfac5e1c84a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java @@ -18,27 +18,36 @@ package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceWriter} for use with structured streaming. + * An interface that defines how to write the data to data source for streaming processing. * * Streaming queries are divided into intervals of data called epochs, with a monotonically * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving -public interface StreamWriter extends DataSourceWriter { +public interface StreamingWriteSupport { + + /** + * Creates a writer factory which will be serialized and sent to executors. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. + */ + StreamingDataWriterFactory createStreamingWriterFactory(); + /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by * {@link DataWriter#commit()}. * * If this method fails (by throwing an exception), this writing job is considered to have been - * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. + * failed, and the execution engine will attempt to call + * {@link #abort(long, WriterCommitMessage[])}. * - * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * The execution engine may call `commit` multiple times for the same epoch in some circumstances. * To support exactly-once data semantics, implementations must ensure that multiple commits for * the same epoch are idempotent. */ @@ -46,7 +55,8 @@ public interface StreamWriter extends DataSourceWriter { /** * Aborts this writing job because some data writers are failed and keep failing when retried, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * the Spark job fails with some unknown reasons, or {@link #commit(long, WriterCommitMessage[])} + * fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. @@ -58,14 +68,4 @@ public interface StreamWriter extends DataSourceWriter { * clean up the data left by data writers. 
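A matching streaming sketch: the factory stamps each writer with the epoch it writes for, and the driver-side `commit` is keyed by `epochId` so a replayed epoch can be detected and ignored, giving the idempotence the javadoc asks for. Names are hypothetical and the "external system" is reduced to a row count:

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}

case class EpochCommitMessage(epochId: Long, rowCount: Long) extends WriterCommitMessage

class CountingDataWriter(epochId: Long) extends DataWriter[InternalRow] {
  private var count = 0L
  override def write(record: InternalRow): Unit = count += 1
  override def commit(): WriterCommitMessage = EpochCommitMessage(epochId, count)
  override def abort(): Unit = count = 0L
}

object CountingWriterFactory extends StreamingDataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
    new CountingDataWriter(epochId)
}

class CountingStreamingWriteSupport extends StreamingWriteSupport {
  // Tracks which epochs were already committed, so duplicate commits are no-ops.
  private val committedEpochs = mutable.Set.empty[Long]

  override def createStreamingWriterFactory(): StreamingDataWriterFactory = CountingWriterFactory

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = synchronized {
    if (committedEpochs.add(epochId)) {
      val total = messages.collect { case EpochCommitMessage(_, n) => n }.sum
      // deliver `total` to the external sink for this epoch here
    }
  }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
}
```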
*/ void abort(long epochId, WriterCommitMessage[] messages); - - default void commit(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Commit without epoch should not be called with StreamWriter"); - } - - default void abort(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Abort without epoch should not be called with StreamWriter"); - } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java new file mode 100644 index 0000000000000..2b018c7d123bb --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.CustomMetrics; + +/** + * A mix in interface for {@link StreamingWriteSupport}. Data sources can implement this interface + * to report custom metrics that gets reported under the + * {@link org.apache.spark.sql.streaming.SinkProgress} + */ +@InterfaceStability.Evolving +public interface SupportsCustomWriterMetrics extends StreamingWriteSupport { + + /** + * Returns custom metrics specific to this data source. + */ + CustomMetrics getCustomMetrics(); + + /** + * Invoked if the custom metrics returned by {@link #getCustomMetrics()} is invalid + * (e.g. Invalid data that cannot be parsed). Throwing an error here would ensure that + * your custom metrics work right and correct values are reported always. The default action + * on invalid metrics is to ignore it. 
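A small sketch of the custom-metrics hook. It assumes `CustomMetrics` (not shown in this excerpt) exposes its value as a JSON string; the payload layout is entirely up to the source, and the names here are hypothetical:

```scala
import org.apache.spark.sql.sources.v2.CustomMetrics

// Hypothetical metrics payload, surfaced under SourceProgress/SinkProgress in query progress.
case class QueueMetrics(pendingRows: Long, droppedRows: Long) extends CustomMetrics {
  override def json(): String =
    s"""{"pendingRows":$pendingRows,"droppedRows":$droppedRows}"""
}

// Inside a read/write support that also mixes in SupportsCustomReaderMetrics or
// SupportsCustomWriterMetrics:
//   override def getCustomMetrics(): CustomMetrics = QueueMetrics(pendingRows = 42L, droppedRows = 0L)
```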
+ * + * @param ex the exception + */ + default void onInvalidMetrics(Exception ex) { + // default is to ignore invalid custom metrics + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 227a16f7e69e9..1c9beda404356 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -162,13 +162,13 @@ public ArrowColumnVector(ValueVector vector) { } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - } else if (vector instanceof NullableMapVector) { - NullableMapVector mapVector = (NullableMapVector) vector; - accessor = new StructAccessor(mapVector); + } else if (vector instanceof StructVector) { + StructVector structVector = (StructVector) vector; + accessor = new StructAccessor(structVector); - childColumns = new ArrowColumnVector[mapVector.size()]; + childColumns = new ArrowColumnVector[structVector.size()]; for (int i = 0; i < childColumns.length; ++i) { - childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); + childColumns[i] = new ArrowColumnVector(structVector.getVectorById(i)); } } else { throw new UnsupportedOperationException(); @@ -455,9 +455,9 @@ final boolean isNullAt(int rowId) { @Override final ColumnarArray getArray(int rowId) { ArrowBuf offsets = accessor.getOffsetBuffer(); - int index = rowId * accessor.OFFSET_WIDTH; + int index = rowId * ListVector.OFFSET_WIDTH; int start = offsets.getInt(index); - int end = offsets.getInt(index + accessor.OFFSET_WIDTH); + int end = offsets.getInt(index + ListVector.OFFSET_WIDTH); return new ColumnarArray(arrayData, start, end - start); } } @@ -472,7 +472,7 @@ final ColumnarArray getArray(int rowId) { */ private static class StructAccessor extends ArrowVectorAccessor { - StructAccessor(NullableMapVector vector) { + StructAccessor(StructVector vector) { super(vector); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad0efbae89830..ae27690f2e5ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability @@ -103,7 +104,7 @@ class TypedColumn[-T, U]( * * {{{ * df("columnName") // On a specific `df` DataFrame. - * col("columnName") // A generic column no yet associated with a DataFrame. + * col("columnName") // A generic column not yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. 
@@ -344,7 +345,7 @@ class Column(val expr: Expression) extends Logging { * * // Java: * import static org.apache.spark.sql.functions.*; - * people.select( people("age").gt(21) ); + * people.select( people.col("age").gt(21) ); * }}} * * @group expr_ops @@ -360,7 +361,7 @@ class Column(val expr: Expression) extends Logging { * * // Java: * import static org.apache.spark.sql.functions.*; - * people.select( people("age").gt(21) ); + * people.select( people.col("age").gt(21) ); * }}} * * @group java_expr_ops @@ -375,7 +376,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") < 21 ) * * // Java: - * people.select( people("age").lt(21) ); + * people.select( people.col("age").lt(21) ); * }}} * * @group expr_ops @@ -390,7 +391,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") < 21 ) * * // Java: - * people.select( people("age").lt(21) ); + * people.select( people.col("age").lt(21) ); * }}} * * @group java_expr_ops @@ -405,7 +406,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") <= 21 ) * * // Java: - * people.select( people("age").leq(21) ); + * people.select( people.col("age").leq(21) ); * }}} * * @group expr_ops @@ -420,7 +421,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") <= 21 ) * * // Java: - * people.select( people("age").leq(21) ); + * people.select( people.col("age").leq(21) ); * }}} * * @group java_expr_ops @@ -435,7 +436,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") >= 21 ) * * // Java: - * people.select( people("age").geq(21) ) + * people.select( people.col("age").geq(21) ) * }}} * * @group expr_ops @@ -450,7 +451,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") >= 21 ) * * // Java: - * people.select( people("age").geq(21) ) + * people.select( people.col("age").geq(21) ) * }}} * * @group java_expr_ops @@ -587,7 +588,7 @@ class Column(val expr: Expression) extends Logging { * people.filter( people("inSchool") || people("isEmployed") ) * * // Java: - * people.filter( people("inSchool").or(people("isEmployed")) ); + * people.filter( people.col("inSchool").or(people.col("isEmployed")) ); * }}} * * @group expr_ops @@ -602,7 +603,7 @@ class Column(val expr: Expression) extends Logging { * people.filter( people("inSchool") || people("isEmployed") ) * * // Java: - * people.filter( people("inSchool").or(people("isEmployed")) ); + * people.filter( people.col("inSchool").or(people.col("isEmployed")) ); * }}} * * @group java_expr_ops @@ -617,7 +618,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("inSchool") && people("isEmployed") ) * * // Java: - * people.select( people("inSchool").and(people("isEmployed")) ); + * people.select( people.col("inSchool").and(people.col("isEmployed")) ); * }}} * * @group expr_ops @@ -632,7 +633,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("inSchool") && people("isEmployed") ) * * // Java: - * people.select( people("inSchool").and(people("isEmployed")) ); + * people.select( people.col("inSchool").and(people.col("isEmployed")) ); * }}} * * @group java_expr_ops @@ -647,7 +648,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") + people("weight") ) * * // Java: - * people.select( people("height").plus(people("weight")) ); + * people.select( people.col("height").plus(people.col("weight")) ); * }}} * * @group expr_ops @@ -662,7 
+663,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") + people("weight") ) * * // Java: - * people.select( people("height").plus(people("weight")) ); + * people.select( people.col("height").plus(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -677,7 +678,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") - people("weight") ) * * // Java: - * people.select( people("height").minus(people("weight")) ); + * people.select( people.col("height").minus(people.col("weight")) ); * }}} * * @group expr_ops @@ -692,7 +693,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") - people("weight") ) * * // Java: - * people.select( people("height").minus(people("weight")) ); + * people.select( people.col("height").minus(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -707,7 +708,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") * people("weight") ) * * // Java: - * people.select( people("height").multiply(people("weight")) ); + * people.select( people.col("height").multiply(people.col("weight")) ); * }}} * * @group expr_ops @@ -722,7 +723,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") * people("weight") ) * * // Java: - * people.select( people("height").multiply(people("weight")) ); + * people.select( people.col("height").multiply(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -737,7 +738,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") / people("weight") ) * * // Java: - * people.select( people("height").divide(people("weight")) ); + * people.select( people.col("height").divide(people.col("weight")) ); * }}} * * @group expr_ops @@ -752,7 +753,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") / people("weight") ) * * // Java: - * people.select( people("height").divide(people("weight")) ); + * people.select( people.col("height").divide(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -780,12 +781,54 @@ class Column(val expr: Expression) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. * + * Note: Since the type of the elements in the list are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * * @group expr_ops * @since 1.5.0 */ @scala.annotation.varargs def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * Note: Since the type of the elements in the collection are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". 
+ * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * + * @group expr_ops + * @since 2.4.0 + */ + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * Note: Since the type of the elements in the collection are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * + * @group java_expr_ops + * @since 2.4.0 + */ + def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) + /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index f3a2b70657c48..5288907b7d7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -494,6 +494,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType case (BooleanType, dt) => dt == BooleanType + case _ => + throw new IllegalArgumentException(s"$targetType is not matched at fillValue") } // Only fill if the column is part of the cols list. 
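A quick usage sketch of the new `isInCollection` variants, assuming a `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.functions.col

val df = spark.range(10).toDF("id")

// Scala: pass any Iterable; equivalent to isin(1, 3, 5).
df.filter(col("id").isInCollection(Seq(1, 3, 5))).show()

// Java callers can pass a java.lang.Iterable instead:
//   df.filter(df.col("id").isInCollection(java.util.Arrays.asList(1, 3, 5)));
```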
if (typeMatches && cols.exists(col => columnEquals(f.name, col))) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d640fdc530ce2..0cfcc45fb3d31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,6 +22,7 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper +import com.univocity.parsers.csv.CsvParser import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability @@ -36,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions, DataSourceV2} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -193,7 +194,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + if (ds.isInstanceOf[BatchReadSupportProvider]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) val pathsOption = { @@ -257,7 +258,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. "fetchsize" can be used to control the - * number of rows per fetch. + * number of rows per fetch and "queryTimeout" can be used to wait + * for a Statement object to execute to the given number of seconds. * @since 1.4.0 */ def jdbc( @@ -372,8 +374,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
    • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
    • + *
    • `encoding` (by default it is not set): allows forcibly setting one of the standard basic + * or extended encodings for the JSON files, for example UTF-16BE or UTF-32LE. If the encoding + * is not specified and `multiLine` is set to `true`, it will be detected automatically.
    • *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
    • + *
    • `samplingRatio` (default is 1.0): defines the fraction of input JSON objects used + * for schema inference.
    • + *
    • `dropFieldIfAllNull` (default `false`): whether to ignore columns of all null values or + * empty arrays/structs during schema inference.
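A minimal sketch of the new JSON reader options documented above, assuming a `SparkSession` named `spark`; the input path is a placeholder:

{{{
// Read multi-line JSON in an explicit charset, sample 80% of the input objects for
// schema inference, and ignore fields that contain only nulls or empty arrays/structs.
val df = spark.read
  .option("encoding", "UTF-16LE")
  .option("multiLine", true)
  .option("samplingRatio", 0.8)
  .option("dropFieldIfAllNull", true)
  .json("/tmp/events")
}}}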
    • * * * @since 2.0.0 @@ -441,7 +450,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { input => rawParser.parse(input, createParser, UTF8String.fromString), parsedOptions.parseMode, schema, - parsedOptions.columnNameOfCorruptRecord) + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) iter.flatMap(parser.parse) } sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming) @@ -468,12 +478,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * it determines the columns as string types and it reads only the first line to determine the * names and the number of fields. * + * If the enforceSchema is set to `false`, only the CSV header in the first line is checked + * to conform specified or inferred schema. + * * @param csvDataset input Dataset with one CSV row per record * @since 2.2.0 */ def csv(csvDataset: Dataset[String]): DataFrame = { val parsedOptions: CSVOptions = new CSVOptions( extraOptions.toMap, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone) val filteredLines: Dataset[String] = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) @@ -492,6 +506,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + val parser = new CsvParser(parsedOptions.asParserSettings) + val columnNames = parser.parseLine(firstLine) + CSVDataSource.checkHeaderColumnNames( + actualSchema, + columnNames, + csvDataset.getClass.getCanonicalName, + parsedOptions.enforceSchema, + sparkSession.sessionState.conf.caseSensitiveAnalysis) filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) @@ -501,7 +523,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { input => Seq(rawParser.parse(input)), parsedOptions.parseMode, schema, - parsedOptions.columnNameOfCorruptRecord) + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) iter.flatMap(parser.parse) } sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming) @@ -532,8 +555,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `comment` (default empty string): sets a single character used for skipping lines * beginning with this character. By default, it is disabled.
    • *
    • `header` (default `false`): uses the first line as names of columns.
    • + *
    • `enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema + * will be forcibly applied to datasource files, and headers in CSV files will be ignored. + * If the option is set to `false` and the `header` option is set to `true`, the schema will be + * validated against all headers in the CSV files. Field names in the schema + * and column names in CSV headers are checked by their positions, taking + * `spark.sql.caseSensitive` into account. Though the default value is `true`, it is recommended + * to disable the `enforceSchema` option to avoid incorrect results.
    • *
    • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
    • + *
    • `samplingRatio` (default is 1.0): defines the fraction of rows used for schema inference.
    • *
    • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading * whitespaces from values being read should be skipped.
    • *
    • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing @@ -575,6 +606,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
    • *
    • `multiLine` (default `false`): parse one record, which may span multiple lines.
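The sketch below illustrates the `enforceSchema` and `samplingRatio` reader options described in this list; the schema, column names and path are made up, and `spark` is an existing `SparkSession`:

{{{
import org.apache.spark.sql.types._

// Validate the CSV header against an explicit schema instead of silently ignoring it.
val schema = new StructType().add("id", LongType).add("name", StringType)
val checked = spark.read
  .schema(schema)
  .option("header", true)
  .option("enforceSchema", false)
  .csv("/tmp/people.csv")

// Or infer the schema from roughly half of the rows.
val inferred = spark.read
  .option("header", true)
  .option("inferSchema", true)
  .option("samplingRatio", 0.5)
  .csv("/tmp/people.csv")
}}}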
    • * + * * @since 2.0.0 */ @scala.annotation.varargs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bbc063148a72c..eca2d5b971905 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import java.text.SimpleDateFormat -import java.util.{Date, Locale, Properties, UUID} +import java.util.{Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -26,12 +25,11 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType @@ -240,21 +238,29 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - ds match { - case ws: WriteSupport => - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = df.sparkSession.sessionState.conf)).asJava) - // Using a timestamp and a random UUID to distinguish different writing jobs. This is good - // enough as there won't be tons of writing jobs created at the same second. 
- val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) - .format(new Date()) + "-" + UUID.randomUUID() - val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) - if (writer.isPresent) { + val source = cls.newInstance().asInstanceOf[DataSourceV2] + source match { + case provider: BatchWriteSupportProvider => + val options = extraOptions ++ + DataSourceV2Utils.extractSessionConfigs(source, df.sparkSession.sessionState.conf) + + val relation = DataSourceV2Relation.create(source, options.toMap) + if (mode == SaveMode.Append) { runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.logicalPlan) + AppendData.byName(relation, df.logicalPlan) + } + + } else { + val writer = provider.createBatchWriteSupport( + UUID.randomUUID().toString, + df.logicalPlan.output.toStructType, + mode, + new DataSourceOptions(options.asJava)) + + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get, df.logicalPlan) + } } } @@ -275,7 +281,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) } } @@ -330,8 +336,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def getBucketSpec: Option[BucketSpec] = { - if (sortColumnNames.isDefined) { - require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + if (sortColumnNames.isDefined && numBuckets.isEmpty) { + throw new AnalysisException("sortBy must be used together with bucketBy") } numBuckets.map { n => @@ -340,14 +346,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def assertNotBucketed(operation: String): Unit = { - if (numBuckets.isDefined || sortColumnNames.isDefined) { - throw new AnalysisException(s"'$operation' does not support bucketing right now") + if (getBucketSpec.isDefined) { + if (sortColumnNames.isEmpty) { + throw new AnalysisException(s"'$operation' does not support bucketBy right now") + } else { + throw new AnalysisException(s"'$operation' does not support bucketBy and sortBy right now") + } } } private def assertNotPartitioned(operation: String): Unit = { if (partitioningColumns.isDefined) { - throw new AnalysisException( s"'$operation' does not support partitioning") + throw new AnalysisException(s"'$operation' does not support partitioning") } } @@ -518,8 +528,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • - *
    • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
    • + *
    • `encoding` (by default it is not set): specifies the encoding (charset) of the saved JSON + * files. If it is not set, the UTF-8 charset will be used.
    • + *
    • `lineSep` (default `\n`): defines the line separator that should be used for writing.
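A short sketch of the JSON writer options above, assuming an existing DataFrame `df`; the output path is illustrative:

{{{
df.write
  .option("encoding", "UTF-8")  // charset of the saved JSON files
  .option("lineSep", "\n")      // separator written between records
  .json("/tmp/out/json")
}}}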
    • * * * @since 1.4.0 @@ -539,8 +550,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
        *
      • `compression` (default is the value specified in `spark.sql.parquet.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive - * shorten names(`none`, `snappy`, `gzip`, and `lzo`). This will override - * `spark.sql.parquet.compression.codec`.
      • + * shortened names (`none`, `uncompressed`, `snappy`, `gzip`, `lzo`, `brotli`, `lz4`, and `zstd`). + * This will override `spark.sql.parquet.compression.codec`. *
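For example, one of the newly listed Parquet codecs can be selected per write, assuming the corresponding codec library is available on the cluster and `df` is an existing DataFrame:

{{{
df.write
  .option("compression", "zstd")  // overrides spark.sql.parquet.compression.codec for this write
  .parquet("/tmp/out/parquet")
}}}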
      * * @since 1.4.0 @@ -589,8 +600,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • - *
    • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
    • + *
    • `lineSep` (default `\n`): defines the line separator that should be used for writing.
    • * * * @since 1.6.0 @@ -625,6 +635,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * enclosed in quotes. Default is to only escape values containing a quote character. *
    • `header` (default `false`): writes the names of columns as the first line.
    • *
    • `nullValue` (default empty string): sets the string representation of a null value.
    • + *
    • `encoding` (by default it is not set): specifies the encoding (charset) of the saved CSV + * files. If it is not set, the UTF-8 charset will be used.
    • *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
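And a sketch of the CSV writer with the new `encoding` option alongside compression; `df` and the path are placeholders:

{{{
df.write
  .option("header", true)
  .option("encoding", "ISO-8859-1")  // charset of the saved CSV files
  .option("compression", "gzip")
  .csv("/tmp/out/csv")
}}}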
    • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 917168162b236..db439b1ee76f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} +import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -65,7 +65,12 @@ private[sql] object Dataset { val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) // Eagerly bind the encoder so we verify that the encoder matches the underlying // schema. The user will get an error if this is not the case. - dataset.deserializer + // optimization: it is guaranteed that [[InternalRow]] can be converted to [[Row]] so + // do not do this check in that case. this check can be expensive since it requires running + // the whole [[Analyzer]] to resolve the deserializer + if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) { + dataset.deserializer + } dataset } @@ -195,9 +200,6 @@ class Dataset[T] private[sql]( } } - // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) - /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -231,16 +233,15 @@ class Dataset[T] private[sql]( } /** - * Compose the string representing rows for output + * Get rows represented in Sequence by specific truncate and vertical requirement. * - * @param _numRows Number of rows to show + * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. - * @param vertical If set to true, prints output rows vertically (one line per column value). */ - private[sql] def showString( - _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0).min(Int.MaxValue - 1) + private[sql] def getRows( + numRows: Int, + truncate: Int): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -251,14 +252,12 @@ class Dataset[T] private[sql]( Column(col).cast(StringType) } } - val takeResult = newDf.select(castCols: _*).take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + val data = newDf.select(castCols: _*).take(numRows + 1) // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." 
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + schema.fieldNames.toSeq +: data.map { row => row.toSeq.map { cell => val str = cell match { case null => "null" @@ -274,6 +273,26 @@ class Dataset[T] private[sql]( } }: Seq[String] } + } + + /** + * Compose the string representing rows for output + * + * @param _numRows Number of rows to show + * @param truncate If set to more than 0, truncates strings to `truncate` characters and + * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). + */ + private[sql] def showString( + _numRows: Int, + truncate: Int = 20, + vertical: Boolean = false): String = { + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. + val tmpRows = getRows(numRows, truncate) + + val hasMoreData = tmpRows.length - 1 > numRows + val rows = tmpRows.take(numRows + 1) val sb = new StringBuilder val numCols = schema.fieldNames.length @@ -291,31 +310,25 @@ class Dataset[T] private[sql]( } } + val paddedRows = rows.map { row => + row.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + } + } + // Create SeparateLine val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - + paddedRows.head.addString(sb, "|", "|", "|\n") sb.append(sep) // data - rows.tail.foreach { - _.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - + paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) sb.append(sep) } else { // Extended display mode enabled @@ -346,7 +359,7 @@ class Dataset[T] private[sql]( } // Print a footer - if (vertical && data.isEmpty) { + if (vertical && rows.tail.isEmpty) { // In a vertical mode, print an empty row set explicitly sb.append("(0 rows)\n") } else if (hasMoreData) { @@ -415,7 +428,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -511,6 +524,16 @@ class Dataset[T] private[sql]( */ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + /** + * Returns true if the `Dataset` is empty. + * + * @group basic + * @since 2.4.0 + */ + def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) == 0 + } + /** * Returns true if this Dataset contains one or more sources that continuously * return data as it arrives. 
A Dataset that reads data from a streaming source @@ -659,7 +682,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) } /** @@ -832,7 +855,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } /** @@ -910,7 +933,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -971,7 +994,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -980,8 +1003,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed - val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -995,6 +1018,11 @@ class Dataset[T] private[sql]( catalyst.expressions.EqualTo( withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name)) + case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualNullSafe( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) }} withPlan { @@ -1013,7 +1041,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) + Join(logicalPlan, right.logicalPlan, joinType = Cross, None) } /** @@ -1045,8 +1073,8 @@ class Dataset[T] private[sql]( // etc. 
val joined = sparkSession.sessionState.executePlan( Join( - this.planWithBarrier, - other.planWithBarrier, + this.logicalPlan, + other.logicalPlan, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1267,7 +1295,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, planWithBarrier) + SubqueryAlias(alias, logicalPlan) } /** @@ -1305,7 +1333,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), planWithBarrier) + Project(cols.map(_.named), logicalPlan) } /** @@ -1360,8 +1388,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, - planWithBarrier) + val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1379,8 +1406,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) + columns.map(_.withInputType(exprEnc, logicalPlan.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1456,7 +1483,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, planWithBarrier) + Filter(condition.expr, logicalPlan) } /** @@ -1607,7 +1634,9 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def reduce(func: (T, T) => T): T = rdd.reduce(func) + def reduce(func: (T, T) => T): T = withNewRDDExecutionId { + rdd.reduce(func) + } /** * :: Experimental :: @@ -1633,15 +1662,14 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = planWithBarrier - val withGroupingKey = AppendColumns(func, inputPlan) + val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) new KeyValueGroupedDataset( encoderFor[K], encoderFor[T], executed, - inputPlan.output, + logicalPlan.output, withGroupingKey.newColumns) } @@ -1779,7 +1807,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), planWithBarrier) + Limit(Literal(n), logicalPlan) } /** @@ -1829,7 +1857,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -1888,7 +1916,7 @@ class Dataset[T] private[sql]( // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. 
- CombineUnions(Union(logicalPlan, rightChild)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, rightChild)) } /** @@ -1902,9 +1930,26 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(planWithBarrier, other.planWithBarrier) + Intersect(logicalPlan, other.logicalPlan, isAll = false) + } + + /** + * Returns a new Dataset containing rows only in both this Dataset and another Dataset while + * preserving the duplicates. + * This is equivalent to `INTERSECT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. Also as standard + * in SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.4.0 + */ + def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator { + Intersect(logicalPlan, other.logicalPlan, isAll = true) } + /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. * This is equivalent to `EXCEPT DISTINCT` in SQL. @@ -1916,7 +1961,23 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planWithBarrier, other.planWithBarrier) + Except(logicalPlan, other.logicalPlan, isAll = false) + } + + /** + * Returns a new Dataset containing rows in this Dataset but not in another Dataset while + * preserving the duplicates. + * This is equivalent to `EXCEPT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. Also as standard in + * SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.4.0 + */ + def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { + Except(logicalPlan, other.logicalPlan, isAll = true) } /** @@ -1967,7 +2028,7 @@ class Dataset[T] private[sql]( */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, planWithBarrier) + Sample(0.0, fraction, withReplacement, seed, logicalPlan) } } @@ -2009,15 +2070,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. 
- val sortOrder = planWithBarrier.output + val sortOrder = logicalPlan.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, planWithBarrier) + Sort(sortOrder, global = false, logicalPlan) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - planWithBarrier + logicalPlan } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -2101,7 +2162,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -2142,7 +2203,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -2293,7 +2354,7 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.planWithBarrier.output + val attrs = this.logicalPlan.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) @@ -2341,7 +2402,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, planWithBarrier) + Deduplicate(groupCols, logicalPlan) } /** @@ -2523,7 +2584,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2537,7 +2598,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2551,7 +2612,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, planWithBarrier) + MapElements[T, U](func, logicalPlan) } /** @@ -2566,7 +2627,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, planWithBarrier)) + withTypedPlan(MapElements[T, U](func, logicalPlan)) } /** @@ -2582,7 +2643,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, planWithBarrier), + MapPartitions[T, U](func, logicalPlan), implicitly[Encoder[U]]) } @@ -2613,7 +2674,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) } /** @@ -2777,7 +2838,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, planWithBarrier) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -2800,7 +2861,7 @@ class Dataset[T] private[sql]( |For range partitioning use repartitionByRange(...) instead. 
""".stripMargin) withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) } } @@ -2838,7 +2899,7 @@ class Dataset[T] private[sql]( case expr: Expression => SortOrder(expr, Ascending) }) withTypedPlan { - RepartitionByExpression(sortOrder, planWithBarrier, numPartitions) + RepartitionByExpression(sortOrder, logicalPlan, numPartitions) } } @@ -2877,7 +2938,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, planWithBarrier) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** @@ -2933,12 +2994,13 @@ class Dataset[T] private[sql]( */ def storageLevel: StorageLevel = { sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => - cachedData.cachedRepresentation.storageLevel + cachedData.cachedRepresentation.cacheBuilder.storageLevel }.getOrElse(StorageLevel.NONE) } /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @param blocking Whether to block until all blocks are deleted. * @@ -2946,12 +3008,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking) + sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking) this } /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @group basic * @since 1.6.0 @@ -2960,7 +3023,7 @@ class Dataset[T] private[sql]( // Represents the `QueryExecution` used to produce the content of the Dataset as an `RDD`. @transient private lazy val rddQueryExecution: QueryExecution = { - val deserialized = CatalystSerde.deserialize[T](planWithBarrier) + val deserialized = CatalystSerde.deserialize[T](logicalPlan) sparkSession.sessionState.executePlan(deserialized) } @@ -3086,7 +3149,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = planWithBarrier, + child = logicalPlan, allowExisting = false, replace = replace, viewType = viewType) @@ -3187,7 +3250,7 @@ class Dataset[T] private[sql]( EvaluatePython.javaToPython(rdd) } - private[sql] def collectToPython(): Int = { + private[sql] def collectToPython(): Array[Any] = { EvaluatePython.registerPicklers() withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) @@ -3197,18 +3260,66 @@ class Dataset[T] private[sql]( } } + private[sql] def getRowsToPython( + _numRows: Int, + truncate: Int): Array[Any] = { + EvaluatePython.registerPicklers() + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + val rows = getRows(numRows, truncate).map(_.toArray).toArray + val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + rows.iterator.map(toJava)) + PythonRDD.serveIterator(iter, "serve-GetRows") + } + /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. 
*/ - private[sql] def collectAsArrowToPython(): Int = { + private[sql] def collectAsArrowToPython(): Array[Any] = { + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = - toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => + val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) + val arrowBatchRdd = toArrowBatchRdd(plan) + val numPartitions = arrowBatchRdd.partitions.length + + // Store collection results for worst case of 1 to N-1 partitions + val results = new Array[Array[Array[Byte]]](numPartitions - 1) + var lastIndex = -1 // index of last partition written + + // Handler to eagerly write partitions to Python in order + def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { + batchWriter.writeBatches(arrowBatches.iterator) + lastIndex += 1 + // Write stored partitions that come next in order + while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 + } + // After last batch, end the stream + if (lastIndex == results.length) { + batchWriter.end() + } + } else { + // Store partitions received out of order + results(index - 1) = arrowBatches + } + } + + sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until numPartitions, + handlePartitionBatches) + } } } - private[sql] def toPythonIterator(): Int = { + private[sql] def toPythonIterator(): Array[Any] = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) } @@ -3287,7 +3398,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, planWithBarrier) + Sort(sortOrder, global = global, logicalPlan) } } @@ -3311,20 +3422,20 @@ class Dataset[T] private[sql]( } } - /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { + /** Convert to an RDD of serialized ArrowRecordBatches. */ + private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator( + ArrowConverters.toBatchIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } // This is only used in tests, for now. - private[sql] def toArrowPayload: RDD[ArrowPayload] = { - toArrowPayload(queryExecution.executedPlan) + private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = { + toArrowBatchRdd(queryExecution.executedPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 86e02e98c01f3..b21c50af18433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -20,10 +20,48 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability /** - * A class to consume data generated by a `StreamingQuery`. 
Typically this is used to send the - * generated data to external systems. Each partition will use a new deserialized instance, so you - * usually should do all the initialization (e.g. opening a connection or initiating a transaction) - * in the `open` method. + * The abstract class for writing custom logic to process data generated by a query. + * This is often used to write the output of a streaming query to arbitrary storage systems. + * Any implementation of this base class will be used by Spark in the following way. + * + *
        + *
      • A single instance of this class is responsible for all the data generated by a single task + * in a query. In other words, one instance is responsible for processing one partition of the + * data generated in a distributed manner. + * + *
      • Any implementation of this class must be serializable because each task will get a fresh + * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that + * any initialization for writing data (e.g. opening a connection or starting a transaction) + * is done after the `open(...)` method has been called, which signifies that the task is + * ready to generate data. + * + *
      • The lifecycle of the methods is as follows. + * + *
        + *   For each partition with `partitionId`:
        + *       For each batch/epoch of streaming data (if it is a streaming query) with `epochId`:
        + *           Method `open(partitionId, epochId)` is called.
        + *           If `open` returns true:
        + *                For each row in the partition and batch/epoch, method `process(row)` is called.
        + *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
        + *   
        + * + *
      + * + * Important points to note: + *
        + *
      • The `partitionId` and `epochId` can be used to deduplicate generated data when failures + * cause reprocessing of some input data. This depends on the execution mode of the query. If + * the streaming query is being executed in the micro-batch mode, then every partition + * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data. + * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data + * and achieve exactly-once guarantees. However, if the streaming query is being executed in the + * continuous mode, then this guarantee does not hold, and (partitionId, epochId) should + * therefore not be used for deduplication. + * + *
      • The `close()` method will be called if the `open()` method returns successfully (irrespective + * of the return value), except if the JVM crashes in the middle. + *
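A sketch of a `ForeachWriter` that follows the open/process/close lifecycle described above; the JDBC URL and the `kv` table are made up for illustration:

{{{
import java.sql.{Connection, DriverManager}

import org.apache.spark.sql.ForeachWriter

class KeyValueSinkWriter extends ForeachWriter[(String, String)] {
  private var connection: Connection = _

  override def open(partitionId: Long, epochId: Long): Boolean = {
    // (partitionId, epochId) could be checked here to skip batches that were already committed.
    connection = DriverManager.getConnection("jdbc:h2:mem:demo")
    true
  }

  override def process(value: (String, String)): Unit = {
    // Called once per row, only if open() returned true.
    val stmt = connection.prepareStatement("INSERT INTO kv VALUES (?, ?)")
    stmt.setString(1, value._1)
    stmt.setString(2, value._2)
    stmt.executeUpdate()
    stmt.close()
  }

  override def close(errorOrNull: Throwable): Unit = {
    // Called whenever open() returned without throwing, regardless of processing errors.
    if (connection != null) connection.close()
  }
}
}}}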
      * * Scala example: * {{{ @@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability * } * }); * }}} + * * @since 2.0.0 */ @InterfaceStability.Evolving @@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. The `version` is - * for data deduplication when there are failures. When recovering from a failure, some data may - * be generated multiple times but they will always have the same version. - * - * If this method finds using the `partitionId` and `version` that this partition has already been - * processed, it can return `false` to skip the further data processing. However, `close` still - * will be called for cleaning up resources. + * Called when starting to process one partition of new data in the executor. See the class + * docs for more information on how to use the `partitionId` and `epochId`. * * @param partitionId the partition id. - * @param version a unique id for data deduplication. + * @param epochId a unique id for data deduplication. * @return `true` if the corresponding partition and version id should be processed. `false` * indicates the partition should be skipped. */ - def open(partitionId: Long, version: Long): Boolean + def open(partitionId: Long, epochId: Long): Boolean /** - * Called to process the data in the executor side. This method will be called only when `open` + * Called to process the data in the executor side. This method will be called only if `open` * returns `true`. */ def process(value: T): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7147798d99533..d700fb83b9b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -62,8 +62,7 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => - Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) @@ -73,7 +72,7 @@ class RelationalGroupedDataset protected[sql]( case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) } } @@ -315,7 +314,67 @@ class RelationalGroupedDataset protected[sql]( * @param pivotColumn Name of the column to pivot. * @since 1.6.0 */ - def pivot(pivotColumn: String): RelationalGroupedDataset = { + def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn)) + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. 
+ * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + pivot(Column(pivotColumn), values) + } + + /** + * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified + * aggregation. + * + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { + pivot(Column(pivotColumn), values) + } + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. + * + * {{{ + * // Or without specifying column values (less efficient) + * df.groupBy($"year").pivot($"course").sum($"earnings"); + * }}} + * + * @param pivotColumn he column to pivot. + * @since 2.4.0 + */ + def pivot(pivotColumn: Column): RelationalGroupedDataset = { // This is to prevent unintended OOM errors when the number of distinct values is large val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues // Get the distinct values of the column and sort them so its consistent @@ -340,29 +399,24 @@ class RelationalGroupedDataset protected[sql]( /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. + * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") + * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") * }}} * - * @param pivotColumn Name of the column to pivot. + * @param pivotColumn the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. 
- * @since 1.6.0 + * @since 2.4.0 */ - def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case RelationalGroupedDataset.GroupByType => new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -372,25 +426,14 @@ class RelationalGroupedDataset protected[sql]( /** * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. - * - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings"); - * }}} + * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of + * the `String` type. * - * @param pivotColumn Name of the column to pivot. + * @param pivotColumn the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. - * @since 1.6.0 + * @since 2.4.0 */ - def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { + def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { pivot(pivotColumn, values.asScala) } @@ -452,7 +495,7 @@ class RelationalGroupedDataset protected[sql]( require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the udf must be a StructType") + s"The returnType of the udf must be a ${StructType.simpleString}") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne @@ -470,8 +513,11 @@ class RelationalGroupedDataset protected[sql]( override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") - val kFields = groupingExprs.map(_.asInstanceOf[NamedExpression]).map { - case f => s"${f.name}: ${f.dataType.simpleString(2)}" + val kFields = groupingExprs.collect { + case expr: NamedExpression if expr.resolved => + s"${expr.name}: ${expr.dataType.simpleString(2)}" + case expr: NamedExpression => expr.name + case o => o.toString } builder.append(kFields.take(2).mkString(", ")) if (kFields.length > 2) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index b352e332bc7e0..3c39579149fff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -132,6 +132,17 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.unsetConf(key) } + /** + * Indicates whether the configuration property with the given key + * is modifiable in the current 
session. + * + * @return `true` if the configuration property is modifiable. For static SQL, Spark Core, + * invalid (not existing) and other non-modifiable configuration properties, + * the returned value is `false`. + * @since 2.4.0 + */ + def isModifiable(key: String): Boolean = sqlConf.isModifiable(key) + /** * Returns whether a particular key is set. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..2b847fb6f9458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -92,7 +92,8 @@ class SparkSession private( // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's. SQLConf.setSQLConfGetter(() => { - SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf) + SparkSession.getActiveSession.filterNot(_.sparkContext.isStopped).map(_.sessionState.conf) + .getOrElse(SQLConf.getFallbackConf) }) /** @@ -269,7 +270,7 @@ class SparkSession private( */ @transient lazy val emptyDataFrame: DataFrame = { - createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) + createDataFrame(sparkContext.emptyRDD[Row].setName("empty"), StructType(Nil)) } /** @@ -394,7 +395,7 @@ class SparkSession private( // BeanInfo is not serializable so we must rediscover it remotely for each partition. SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq) } - Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self)) + Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd.setName(rdd.name))(self)) } /** @@ -593,7 +594,7 @@ class SparkSession private( } else { rowRDD.map { r: Row => InternalRow.fromSeq(r.toSeq) } } - internalCreateDataFrame(catalystRows, schema) + internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema) } @@ -898,6 +899,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1020,16 +1022,34 @@ object SparkSession extends Logging { /** * Returns the active SparkSession for the current thread, returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(activeThreadSession.get) + } + } /** * Returns the default SparkSession that is returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + if (TaskContext.get != null) { + // Return None when running on executors. 
+ None + } else { + Option(defaultSession.get) + } + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1082,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f94baef39dfad..24ee46d0e8147 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i] :: $s"}) println(s""" |/** | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). @@ -122,9 +122,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputTypes = Try($inputTypes).toOption + | val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try($inputTypes).toOption | def builder(e: Seq[Expression]) = if (e.length == $x) { - | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + | udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". Expected: $x; Found: " + e.length) @@ -167,9 +168,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 0; Found: " + e.length) @@ -186,9 +188,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) @@ -205,9 +208,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) @@ -224,9 +228,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 3; Found: " + e.length) @@ -243,9 +248,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) @@ -262,9 +268,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 5; Found: " + e.length) @@ -281,9 +288,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) @@ -300,9 +308,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 7; Found: " + e.length) @@ -319,9 +328,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) @@ -338,9 +348,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 9; Found: " + e.length) @@ -357,9 +368,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 10; Found: " + e.length) @@ -376,9 +388,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 11; Found: " + e.length) @@ -395,9 +408,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 12; Found: " + e.length) @@ -414,9 +428,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 13; Found: " + e.length) @@ -433,9 +448,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 14; Found: " + e.length) @@ -452,9 +468,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 15; Found: " + e.length) @@ -471,9 +488,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 16; Found: " + e.length) @@ -490,9 +508,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 17; Found: " + e.length) @@ -509,9 +528,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 18; Found: " + e.length) @@ -528,9 +548,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 19; Found: " + e.length) @@ -547,9 +568,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 20; Found: " + e.length) @@ -566,9 +588,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 21; Found: " + e.length) @@ -585,9 +608,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: ScalaReflection.schemaFor[A22] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 22; Found: " + e.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index b33760b1edbc6..c0830e77b5a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.api.python -import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each serialized ArrowRecordBatch as a partition. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. * @param sqlContext The active [[SQLContext]]. - * @return The converted [[DataFrame]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], - schemaString: String, - sqlContext: SQLContext): DataFrame = { - ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + def arrowReadStreamFromFile( + sqlContext: SQLContext, + filename: String, + schemaString: String): DataFrame = { + val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename) + ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a8794be7280c7..c9929935fb8ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -71,7 +71,7 @@ class CacheManager extends Logging { /** Clears all cached tables. */ def clearCache(): Unit = writeLock { - cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache()) cachedData.clear() } @@ -105,24 +105,50 @@ class CacheManager extends Logging { } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param query The [[Dataset]] to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * [[Dataset]]; otherwise un-cache the given [[Dataset]] only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + def uncacheQuery( + query: Dataset[_], + cascade: Boolean, + blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking) } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param spark The Spark session. + * @param plan The plan to be un-cached. 
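For context on the `UDFRegistration` changes above: the generated `register` overloads now keep the whole `ScalaReflection.Schema` per argument, so the `ScalaUDF` receives the input data types and, via the new `nullableTypes` argument, the nullability implied by each Scala parameter type (primitives such as `Int` are non-nullable, reference types such as `String` are nullable). A small illustrative snippet for spark-shell, with invented function names:

```scala
// Both inputs are primitive Ints, so their schemas are reported as non-nullable.
spark.udf.register("clampedSum", (a: Int, b: Int) => math.min(a + b, 100))

// A String input is nullable; the UDF handles null explicitly.
spark.udf.register("shout", (s: String) => if (s == null) null else s.toUpperCase)

spark.sql("SELECT clampedSum(id, 60) AS s, shout('ok') AS t FROM range(50, 53)").show()
```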
+ * @param cascade If true, un-cache all the cache entries that refer to the given + * plan; otherwise un-cache the given plan only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + def uncacheQuery( + spark: SparkSession, + plan: LogicalPlan, + cascade: Boolean, + blocking: Boolean): Unit = writeLock { + val shouldRemove: LogicalPlan => Boolean = + if (cascade) { + _.find(_.sameResult(plan)).isDefined + } else { + _.sameResult(plan) + } val it = cachedData.iterator() while (it.hasNext) { val cd = it.next() - if (cd.plan.find(_.sameResult(plan)).isDefined) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + if (shouldRemove(cd.plan)) { + cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } + // Re-compile dependent cached queries after removing the cached query. + if (!cascade) { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined, clearCache = false) + } } /** @@ -132,22 +158,24 @@ class CacheManager extends Logging { recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) } - private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + private def recacheByCondition( + spark: SparkSession, + condition: LogicalPlan => Boolean, + clearCache: Boolean = true): Unit = { val it = cachedData.iterator() val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist() + if (clearCache) { + cd.cachedRepresentation.cacheBuilder.clearCache() + } // Remove the cache entry before we create a new one, so that we can have a different // physical plan. 
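The `cascade` flag above reduces to a choice of predicate over the cached plans: with `cascade = true` any entry whose plan contains the target plan is dropped, otherwise only the entry for that exact plan is dropped and dependent entries are merely re-compiled. A toy, self-contained sketch of just that predicate choice, using stand-in classes rather than Catalyst plans:

```scala
// Minimal stand-in for a plan tree with `find` and `sameResult`.
final case class Plan(name: String, children: Seq[Plan] = Nil) {
  def find(p: Plan => Boolean): Option[Plan] =
    if (p(this)) Some(this) else children.flatMap(_.find(p)).headOption
  def sameResult(other: Plan): Boolean = this == other
}

object CascadeDemo extends App {
  val target = Plan("table")
  val cached = Seq(Plan("table"), Plan("view_over_table", Seq(Plan("table"))))

  def toRemove(cascade: Boolean): Seq[String] = {
    val shouldRemove: Plan => Boolean =
      if (cascade) p => p.find(_.sameResult(target)).isDefined
      else p => p.sameResult(target)
    cached.filter(shouldRemove).map(_.name)
  }

  println(toRemove(cascade = true))  // List(table, view_over_table)
  println(toRemove(cascade = false)) // List(table)
}
```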
it.remove() + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( - useCompression = cd.cachedRepresentation.useCompression, - batchSize = cd.cachedRepresentation.batchSize, - storageLevel = cd.cachedRepresentation.storageLevel, - child = spark.sessionState.executePlan(cd.plan).executedPlan, - tableName = cd.cachedRepresentation.tableName, + cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fc3dbc1c5591b..48abad9078650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { - s""" + val code = code"${ctx.registerComment(str)}" + (if (nullable) { + code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { - s"$javaType $valueVar = $value;" - }).trim + code"$javaType $valueVar = $value;" + }) ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3d..36ed016773b67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +import org.apache.spark.util.collection.BitSet trait DataSourceScanExec extends LeafExecNode with CodegenSupport { val relation: BaseRelation @@ -69,7 +70,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { * Shorthand for calling redactString() without specifying redacting rules */ private def redact(text: String): String = { - Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text) + Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) } } @@ -151,6 +152,7 @@ case class RowDataSourceScanExec( * @param output Output attributes of the scan, including data attributes and partition attributes. * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. + * @param optionalBucketSet Bucket ids for bucket pruning * @param dataFilters Filters on non-partition columns. 
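The new `optionalBucketSet` parameter documented above (and wired through the hunks that follow) carries the surviving bucket ids as a `BitSet`, so an equality filter on the bucketing column lets the scan skip whole bucket files. A rough way to see it from spark-shell; table and column names are made up and the exact plan text may differ, but the scan's metadata should include something like "SelectedBucketsCount: 1 out of 8":

```scala
spark.range(0, 100000)
  .selectExpr("id", "id % 97 AS key")
  .write.bucketBy(8, "key").sortBy("key").mode("overwrite").saveAsTable("bucketed_demo")

spark.table("bucketed_demo").where("key = 3").explain()
```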
* @param tableIdentifier identifier for the table in the metastore. */ @@ -159,14 +161,17 @@ case class FileSourceScanExec( output: Seq[Attribute], requiredSchema: StructType, partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], dataFilters: Seq[Expression], override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - override val supportsBatch: Boolean = relation.fileFormat.supportBatch( + // Note that some vals referring the file-based relation are lazy intentionally + // so that this plan can be canonicalized on executor side too. See SPARK-23731. + override lazy val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - override val needsUnsafeRowConversion: Boolean = { + override lazy val needsUnsafeRowConversion: Boolean = { if (relation.fileFormat.isInstanceOf[ParquetSource]) { SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled } else { @@ -196,7 +201,7 @@ case class FileSourceScanExec( ret } - override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { relation.bucketSpec } else { @@ -267,7 +272,7 @@ case class FileSourceScanExec( private val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") - override val metadata: Map[String, String] = { + override lazy val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") val location = relation.location val locationDesc = @@ -286,7 +291,20 @@ case class FileSourceScanExec( } getOrElse { metadata } - withOptPartitionCount + + val withSelectedBucketsCount = relation.bucketSpec.map { spec => + val numSelectedBuckets = optionalBucketSet.map { b => + b.cardinality() + } getOrElse { + spec.numBuckets + } + withOptPartitionCount + ("SelectedBucketsCount" -> + s"$numSelectedBuckets out of ${spec.numBuckets}") + } getOrElse { + withOptPartitionCount + } + + withSelectedBucketsCount } private lazy val inputRDD: RDD[InternalRow] = { @@ -365,7 +383,7 @@ case class FileSourceScanExec( selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val bucketed = + val filesGroupedToBuckets = selectedPartitions.flatMap { p => p.files.map { f => val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen) @@ -377,8 +395,17 @@ case class FileSourceScanExec( .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) } + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { + f => bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil)) } new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) @@ -503,6 +530,7 @@ case class FileSourceScanExec( output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, QueryPlan.normalizePredicates(partitionFilters, output), + optionalBucketSet, QueryPlan.normalizePredicates(dataFilters, output), None) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index be50a1571a2ff..2962becb64e88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -103,6 +103,10 @@ case class ExternalRDDScanExec[T]( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") + + override val nodeName: String = s"Scan$rddName" + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val outputDataType = outputObjAttr.dataType @@ -116,7 +120,7 @@ case class ExternalRDDScanExec[T]( } override def simpleString: String = { - s"Scan $nodeName${output.mkString("[", ",", "]")}" + s"$nodeName${output.mkString("[", ",", "]")}" } } @@ -169,10 +173,14 @@ case class LogicalRDD( case class RDDScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], - override val nodeName: String, + name: String, override val outputPartitioning: Partitioning = UnknownPartitioning(0), override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode { + private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") + + override val nodeName: String = s"Scan $name$rddName" + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -189,6 +197,6 @@ case class RDDScanExec( } override def simpleString: String = { - s"Scan $nodeName${Utils.truncatedString(output, "[", ",", "]")}" + s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index e4812f3d338fb..5b4edf5136e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -152,7 +153,7 @@ case class ExpandExec( } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val code = s""" + val code = code""" |boolean $isNull = true; |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f40c50df74ccb..2549b9e1537a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import 
org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ @@ -313,13 +314,13 @@ case class GenerateExec( if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) + ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala new file mode 100644 index 0000000000000..c88b2f8c034fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField} +import org.apache.spark.sql.types.StructField + +/** + * A Scala extractor that extracts the child expression and struct field from a [[GetStructField]]. + * This is in contrast to the [[GetStructField]] case class extractor which returns the field + * ordinal instead of the field itself. + */ +private[execution] object GetStructFieldObject { + def unapply(getStructField: GetStructField): Option[(Expression, StructField)] = + Some(( + getStructField.child, + getStructField.childSchema(getStructField.ordinal))) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 514ad7018d8c7..448eb703eacde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. + * + * `Seq` may not be serializable and ideally we should not send `rows` and `unsafeRows` + * to the executors. Thus marking them as transient. 
*/ case class LocalTableScanExec( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index acbd4becb8549..3ca03ab2939aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} @@ -49,9 +50,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => + case a @ Aggregate(_, aggExprs, child @ PhysicalOperation( projectList, filters, PartitionedRelation(partAttrs, rel))) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(attrs)) { + if (AttributeSet((projectList ++ filters).flatMap(_.references)).subsetOf(partAttrs)) { + // The project list and filters all only refer to partition attributes, which means the + // Aggregate operator can also only refer to partition attributes, and filters are + // all partition filters. This is a metadata-only query we can optimize. val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -102,7 +107,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic partFilters: Seq[Expression]): LogicalPlan = { // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the // relation's schema. PartitionedRelation ensures that the filters only reference partition cols - val relFilters = partFilters.map { e => + val normalizedFilters = partFilters.map { e => e transform { case a: AttributeReference => a.withName(relation.output.find(_.semanticEquals(a)).get.name) @@ -114,11 +119,8 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(relFilters, Nil) - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem.
- LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) + val partitionData = fsRelation.location.listFiles(normalizedFilters, Nil) + LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -127,7 +129,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) val partitions = if (partFilters.nonEmpty) { - catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + catalog.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters) } else { catalog.listPartitions(relation.tableMeta.identifier) } @@ -137,10 +139,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.toArray) + LocalRelation(partAttrs, partitionData) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -151,44 +150,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes, partition filters, and the table relation node. - * - * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with - * deterministic expressions, and returns result after reaching the partitioned table relation - * node. + * pair of the partition attributes and the table relation node. 
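+ *
+ * A sketch of the intended use, mirroring the match in `OptimizeMetadataOnlyQuery` above
+ * (projection and filter traversal is now handled by [[PhysicalOperation]] rather than here):
+ * {{{
+ *   plan match {
+ *     case PhysicalOperation(projectList, filters, PartitionedRelation(partAttrs, relation)) =>
+ *       // a metadata-only candidate when projectList and filters only reference partAttrs
+ *   }
+ * }}}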
*/ object PartitionedRelation extends PredicateHelper { - def unapply( - plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = { plan match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => + if fsRelation.partitionSchema.nonEmpty => val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) - Some((partAttrs, partAttrs, Nil, l)) + Some((partAttrs, l)) case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => val partAttrs = AttributeSet( getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) - Some((partAttrs, partAttrs, Nil, relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (p.references.subsetOf(attrs)) { - Some((partAttrs, p.outputSet, filters, relation)) - } else { - None - } - } - - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (f.references.subsetOf(partAttrs)) { - Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) - } else { - None - } - } + Some((partAttrs, relation)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala new file mode 100644 index 0000000000000..2236f18b0da12 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A Scala extractor that projects an expression over a given schema. Data types, + * field indexes and field counts of complex type extractors and attributes + * are adjusted to fit the schema. All other expressions are left as-is. This + * class is motivated by columnar nested schema pruning. 
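+ *
+ * A sketch of how the extractor can be applied when rewriting expressions against a pruned
+ * schema (`prunedSchema` and `expr` below are illustrative names, not members of this class):
+ * {{{
+ *   val projectionOverSchema = ProjectionOverSchema(prunedSchema)
+ *   val rewritten = expr.transformDown {
+ *     case projectionOverSchema(adjusted) => adjusted
+ *   }
+ * }}}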
+ */ +private[execution] case class ProjectionOverSchema(schema: StructType) { + private val fieldNames = schema.fieldNames.toSet + + def unapply(expr: Expression): Option[Expression] = getProjection(expr) + + private def getProjection(expr: Expression): Option[Expression] = + expr match { + case a: AttributeReference if fieldNames.contains(a.name) => + Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier)) + case GetArrayItem(child, arrayItemOrdinal) => + getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) } + case a: GetArrayStructFields => + getProjection(a.child).map(p => (p, p.dataType)).map { + case (projection, ArrayType(projSchema @ StructType(_), _)) => + GetArrayStructFields(projection, + projSchema(a.field.name), + projSchema.fieldIndex(a.field.name), + projSchema.size, + a.containsNull) + } + case GetMapValue(child, key) => + getProjection(child).map { projection => GetMapValue(projection, key) } + case GetStructFieldObject(child, field: StructField) => + getProjection(child).map(p => (p, p.dataType)).map { + case (projection, projSchema: StructType) => + GetStructField(projection, projSchema.fieldIndex(field.name)) + } + case _ => + None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 15379a0663f7d..64f49e2d0d4e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( - python.ExtractPythonUDFs, PlanSubqueries(sparkSession), EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), @@ -225,7 +224,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * Redact the sensitive information in the given string. */ private def withRedaction(message: String): String = { - Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message) + Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message) } /** A special namespace for commands that can be used to debug query execution. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index e991da7df0bde..439932b0cc3ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -88,15 +90,43 @@ object SQLExecution { /** * Wrap an action with a known executionId. When running a different action in a different * thread from the original one, this method can be used to connect the Spark jobs in this action - * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`. + * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + /** + * Wrap an action with specified SQL configs. These configs will be propagated to the executor + * side via job local properties. + */ + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. 
+ val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala new file mode 100644 index 0000000000000..0e7c593f9fb67 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A Scala extractor that builds a [[org.apache.spark.sql.types.StructField]] from a Catalyst + * complex type extractor. For example, consider a relation with the following schema: + * + * {{{ + * root + * |-- name: struct (nullable = true) + * | |-- first: string (nullable = true) + * | |-- last: string (nullable = true) + * }}} + * + * Further, suppose we take the select expression `name.first`. This will parse into an + * `Alias(child, "first")`. Ignoring the alias, `child` matches the following pattern: + * + * {{{ + * GetStructFieldObject( + * AttributeReference("name", StructType(_), _, _), + * StructField("first", StringType, _, _)) + * }}} + * + * [[SelectedField]] converts that expression into + * + * {{{ + * StructField("name", StructType(Array(StructField("first", StringType)))) + * }}} + * + * by mapping each complex type extractor to a [[org.apache.spark.sql.types.StructField]] with the + * same name as its child (or "parent" going right to left in the select expression) and a data + * type appropriate to the complex type extractor. In our example, the name of the child expression + * is "name" and its data type is a [[org.apache.spark.sql.types.StructType]] with a single string + * field named "first". + * + * @param expr the top-level complex type extractor + */ +private[execution] object SelectedField { + def unapply(expr: Expression): Option[StructField] = { + // If this expression is an alias, work on its child instead + val unaliased = expr match { + case Alias(child, _) => child + case expr => expr + } + selectField(unaliased, None) + } + + private def selectField(expr: Expression, fieldOpt: Option[StructField]): Option[StructField] = { + expr match { + // No children. 
Returns a StructField with the attribute name or None if fieldOpt is None. + case AttributeReference(name, dataType, nullable, metadata) => + fieldOpt.map(field => + StructField(name, wrapStructType(dataType, field), nullable, metadata)) + // Handles case "expr0.field[n]", where "expr0" is of struct type and "expr0.field" is of + // array type. + case GetArrayItem(x @ GetStructFieldObject(child, field @ StructField(name, + dataType, nullable, metadata)), _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), nullable, metadata)).getOrElse(field) + selectField(child, Some(childField)) + // Handles case "expr0.field[n]", where "expr0.field" is of array type. + case GetArrayItem(child, _) => + selectField(child, fieldOpt) + // Handles case "expr0.field.subfield", where "expr0" and "expr0.field" are of array type. + case GetArrayStructFields(child: GetArrayStructFields, + field @ StructField(name, dataType, nullable, metadata), _, _, _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field", where "expr0" is of array type. + case GetArrayStructFields(child, + field @ StructField(name, dataType, nullable, metadata), _, _, _) => + val childField = + fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of + // map type. + case GetMapValue(x @ GetStructFieldObject(child, field @ StructField(name, + dataType, + nullable, metadata)), _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field[key]", where "expr0.field" is of map type. + case GetMapValue(child, _) => + selectField(child, fieldOpt) + // Handles case "expr0.field", where expr0 is of struct type. + case GetStructFieldObject(child, + field @ StructField(name, dataType, nullable, metadata)) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + case _ => + None + } + } + + // Constructs a composition of complex types with a StructType(Array(field)) at its core. Returns + // a StructType for a StructType, an ArrayType for an ArrayType and a MapType for a MapType. 
+ private def wrapStructType(dataType: DataType, field: StructField): DataType = { + dataType match { + case _: StructType => + StructType(Array(field)) + case ArrayType(elementType, containsNull) => + ArrayType(wrapStructType(elementType, field), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType(keyType, wrapStructType(valueType, field), valueContainsNull) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978dc..6c6d344240cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,22 +21,26 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource -import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning +import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs} class SparkOptimizer( catalog: SessionCatalog, experimentalMethods: ExperimentalMethods) extends Optimizer(catalog) { - override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ + override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ - Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ + Batch("Extract Python UDFs", Once, + Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ - Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ + Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + override def nonExcludableRules: Seq[String] = + super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName + /** * Optimization batches that are executed before the regular optimization batches (also before * the finish analysis batch). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 398758a3331b4..1f97993e20458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -47,17 +47,15 @@ import org.apache.spark.util.ThreadUtils abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { /** - * A handle to the SQL Context that was used to create this plan. Since many operators need + * A handle to the SQL Context that was used to create this plan. Since many operators need * access to the sqlContext for RDD operations or configuration this field is automatically * populated by the query planning infrastructure. 
*/ - @transient - final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull + @transient final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when SparkPlan nodes are created without the active sessions. - // So far, this only happens in the test cases. val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.subexpressionEliminationEnabled } else { @@ -69,7 +67,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SparkSession.setActiveSession(sqlContext.sparkSession) + if (sqlContext != null) { + SparkSession.setActiveSession(sqlContext.sparkSession) + } super.makeCopy(newArgs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 74048871f8d42..2a4a1c8ef3438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,11 +36,13 @@ class SparkPlanner( override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( + PythonEvals :: DataSourceV2Strategy :: FileSourceStrategy :: DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: + Window :: JoinSelection :: InMemoryScans :: BasicOperators :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 4828fa60a7b58..89cb63784c0f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1458,6 +1458,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + // ALTER VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("ALTER VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("ALTER VIEW ... AS FROM ... 
[INSERT INTO ...]+", ctx) + case _ => // OK + } AlterViewAsCommand( name = visitTableIdentifier(ctx.tableIdentifier), originalText = source(ctx.query), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82b4eb9fba242..dbc6db62bd820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,14 +27,16 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} import org.apache.spark.sql.types.StructType /** @@ -66,22 +68,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object SpecialLimits extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => - // With whole stage codegen, Spark releases resources only when all the output data of the - // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little - // data from child plan and finishes the query without releasing resources. Here we wrap - // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and - // trigger the resource releasing work, after we consume `limit` rows. 
- CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } @@ -323,15 +324,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) { throw new AnalysisException( "Streaming aggregation doesn't support group aggregate pandas UDF") } + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, + stateVersion, planLater(child)) case _ => Nil @@ -350,6 +354,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming global limit operator for streams in append mode. + * We need to check for either a direct Limit or a Limit wrapped in a ReturnAnswer operator, + * following the example of the SpecialLimits Strategy above. + * Streams with limit in Append mode use the stateful StreamingGlobalLimitExec. + * Streams with limit in Complete mode use the stateless CollectLimitExec operator. + * Limit is unsupported for streams in Update mode. 
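+ *
+ * For example, a `limit(10)` on an Append-mode stream is planned roughly as
+ * (a sketch of the resulting physical operators, not an exact plan string):
+ * {{{
+ *   StreamingGlobalLimitExec(10, LocalLimitExec(10, child))
+ * }}}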
+ */ + case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReturnAnswer(rootPlan) => rootPlan match { + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + } + object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { @@ -361,7 +388,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case Join(left, right, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( - "Stream stream joins without equality predicate is not supported", plan = Some(plan)) + "Stream-stream join without equality predicate is not supported", plan = Some(plan)) case _ => Nil } @@ -380,9 +407,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } @@ -424,6 +451,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object Window extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalWindow( + WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) => + execution.window.WindowExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case PhysicalWindow( + WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) => + execution.python.WindowInPandasExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { @@ -465,16 +508,30 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case FlatMapGroupsWithState( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val execPlan = FlatMapGroupsWithStateExec( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, - timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion, + outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil } } - // Can we automate these 'pass through' operations? + /** + * Strategy to convert EvalPython logical operator to physical operator. 
+ */ + object PythonEvals extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ArrowEvalPython(udfs, output, child) => + ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil + case BatchEvalPython(udfs, output, child) => + BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil + case _ => + Nil + } + } + object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil @@ -490,12 +547,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") - case logical.Intersect(left, right) => + case logical.Intersect(left, right, false) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.Intersect(left, right, true) => throw new IllegalStateException( - "logical intersect operator should have been replaced by semi-join in the optimizer") - case logical.Except(left, right) => + "logical intersect operator should have been replaced by union, aggregate" + + " and generate operators in the optimizer") + case logical.Except(left, right, false) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") + case logical.Except(left, right, true) => + throw new IllegalStateException( + "logical except (all) operator should have been replaced by union, aggregate" + + " and generate operators in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil @@ -544,8 +609,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil - case logical.Window(windowExprs, partitionSpec, orderSpec, child) => - execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 828b51fa199de..1fc4de9e56015 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -21,12 +21,14 @@ import java.util.Locale import java.util.function.Supplier import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -122,10 +124,10 @@ trait CodegenSupport extends SparkPlan { 
ctx.INPUT_ROW = row ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" + val code = code""" |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim + |${ev.code} + """.stripMargin ExprCode(code, FalseLiteral, ev.value) } else { // There are no columns @@ -259,8 +261,8 @@ trait CodegenSupport extends SparkPlan { * them to be evaluated twice. */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") - variables.foreach(_.code = "") + val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n") + variables.foreach(_.code = EmptyBlock) evaluate } @@ -274,9 +276,9 @@ trait CodegenSupport extends SparkPlan { required: AttributeSet): String = { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => - if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars.append(ev.code.trim + "\n") - ev.code = "" + if (ev.code.nonEmpty && required.contains(attributes(i))) { + evaluateVars.append(ev.code.toString + "\n") + ev.code = EmptyBlock } } evaluateVars.toString() @@ -581,7 +583,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) val (_, maxCodeSize) = try { CodeGenerator.compile(cleanedSource) } catch { - case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => + case NonFatal(_) if !Utils.isTesting && sqlContext.conf.codegenFallback => // We should already saw the error message logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString") return child.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ebbdf1aaa024d..6be88c463dbd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -177,6 +177,10 @@ object AggUtils { case agg @ AggregateExpression(aggregateFunction, mode, true, _) => aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] + case agg => + throw new IllegalArgumentException( + "Non-distinct aggregate is found in functionsWithDistinct " + + s"at planAggregateWithOneDistinct: $agg") } val partialDistinctAggregate: SparkPlan = { @@ -256,6 +260,7 @@ object AggUtils { groupingExpressions: Seq[NamedExpression], functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], + stateFormatVersion: Int, child: SparkPlan): Seq[SparkPlan] = { val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -287,7 +292,8 @@ object AggUtils { child = partialAggregate) } - val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) + val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, + partialMerged1) val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) @@ -311,6 +317,7 @@ object AggUtils { stateInfo = None, outputMode = None, eventTimeWatermark = None, + stateFormatVersion = stateFormatVersion, partialMerged2) val finalAndCompleteAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 
a5dc6ebf2b0f2..98adba50b2973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -190,7 +191,7 @@ case class HashAggregateExec( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) - val initVars = s""" + val initVars = code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin @@ -327,7 +328,7 @@ case class HashAggregateExec( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) @@ -578,6 +579,7 @@ case class HashAggregateExec( case _ => } } + val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit val thisPlan = ctx.addReferenceObj("plan", this) @@ -587,7 +589,7 @@ case class HashAggregateExec( val fastHashMapClassName = ctx.freshName("FastHashMap") if (isVectorizedHashMapEnabled) { val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema).generate() + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() ctx.addInnerClass(generatedMap) // Inline mutable state since not many aggregation operations in a task @@ -597,7 +599,7 @@ case class HashAggregateExec( forceInline = true) } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema).generate() + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() ctx.addInnerClass(generatedMap) // Inline mutable state since not many aggregation operations in a task @@ -755,7 +757,10 @@ case class HashAggregateExec( } // generate hash code for key - val hashExpr = Murmur3Hash(groupingExpressions, 42) + // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions + // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, + // pick a different seed to avoid this conflict + val hashExpr = Murmur3Hash(groupingExpressions, 48) val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, @@ -770,8 +775,8 @@ case class HashAggregateExec( val findOrInsertRegularHashMap: String = s""" |// generate grouping key - |${unsafeRowKeyCode.code.trim} - |${hashEval.code.trim} + |${unsafeRowKeyCode.code} + |${hashEval.code} |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index de2d630de3fdb..e1c85823259b1 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ abstract class HashMapGenerator( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = - s""" + code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index d5508275c48c5..3d2443ca959a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -39,47 +39,23 @@ class RowBasedHashMapGenerator( aggregateExpressions: Seq[AggregateExpression], generatedClassName: String, groupingKeySchema: StructType, - bufferSchema: StructType) + bufferSchema: StructType, + bitMaxCapacity: Int) extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, groupingKeySchema, bufferSchema) { override protected def initializeAggregateHashMap(): String = { - val generatedKeySchema: String = - s"new org.apache.spark.sql.types.StructType()" + - groupingKeySchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") - - val generatedValueSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - bufferSchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema) + val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema) s""" | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; | private int[] buckets; - | private int capacity = 1 << 16; + | private int capacity = 1 << $bitMaxCapacity; | private double loadFactor = 0.5; | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema - | private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema | private Object emptyVBase; | private long emptyVOff; | private int emptyVLen; @@ -90,9 +66,9 @@ class RowBasedHashMapGenerator( | org.apache.spark.memory.TaskMemoryManager taskMemoryManager, | InternalRow emptyAggregationBuffer) { | batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch - | 
.allocate(keySchema, valueSchema, taskMemoryManager, capacity); + | .allocate($keySchema, $valueSchema, taskMemoryManager, capacity); | - | final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); + | final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema); | final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); | | emptyVBase = emptyBuffer; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 9dc334c1ead3c..72505f7fac0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -166,7 +166,7 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) @@ -372,7 +372,7 @@ class TungstenAggregationIterator( } } - TaskContext.get().addTaskCompletionListener(_ => { + TaskContext.get().addTaskCompletionListener[Unit](_ => { // At the end of the task, update the task's peak memory usage. Since we destroy // the map to create the sorter, their memory usages should not overlap, so it is safe // to just use the max of the two. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index aab8cc50b9526..6d44890704f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object TypedAggregateExpression { def apply[BUF : Encoder, OUT : Encoder]( @@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction { s"$nodeName($input)" } - override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") + // aggregator.getClass.getSimpleName can throw a "Malformed class name" error, + // so call the safer `Utils.getSimpleName` instead + override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$"); } // TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.
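The `Utils.getSimpleName` change above works around `java.lang.Class.getSimpleName`, which may throw `java.lang.InternalError: Malformed class name` for some nested Scala classes and objects. A minimal sketch of that kind of fallback, assuming it is acceptable to derive the name from `Class.getName` (this is not Spark's actual `Utils.getSimpleName` implementation):

object SafeSimpleName {
  // Class.getSimpleName may throw InternalError("Malformed class name") for nested
  // Scala classes/objects; fall back to deriving a simple name from Class.getName.
  def apply(cls: Class[_]): String = {
    try cls.getSimpleName catch {
      case _: InternalError =>
        val fullName = cls.getName
        val withoutPackage = fullName.substring(fullName.lastIndexOf('.') + 1)
        // Keep the last non-empty segment after any '$' separators (outer classes, objects).
        withoutPackage.split('$').filter(_.nonEmpty).lastOption.getOrElse(withoutPackage)
    }
  }
}

As in the change above, the caller can still apply `.stripSuffix("$")` to drop a trailing companion-object marker.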
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 7b3580cecc60d..f9c4ecc14e6c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -47,59 +47,35 @@ class VectorizedHashMapGenerator( aggregateExpressions: Seq[AggregateExpression], generatedClassName: String, groupingKeySchema: StructType, - bufferSchema: StructType) + bufferSchema: StructType, + bitMaxCapacity: Int) extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, groupingKeySchema, bufferSchema) { override protected def initializeAggregateHashMap(): String = { - val generatedSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - (groupingKeySchema ++ bufferSchema).map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") - - val generatedAggBufferSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - bufferSchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val schemaStructType = new StructType((groupingKeySchema ++ bufferSchema).toArray) + val schema = ctx.addReferenceObj("schemaTerm", schemaStructType) + val aggBufferSchemaFieldsLength = bufferSchema.fields.length s""" | private ${classOf[OnHeapColumnVector].getName}[] vectors; | private ${classOf[ColumnarBatch].getName} batch; | private ${classOf[MutableColumnarRow].getName} aggBufferRow; | private int[] buckets; - | private int capacity = 1 << 16; + | private int capacity = 1 << $bitMaxCapacity; | private double loadFactor = 0.5; | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType schema = $generatedSchema - | private org.apache.spark.sql.types.StructType aggregateBufferSchema = - | $generatedAggBufferSchema | | public $generatedClassName() { - | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); + | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, $schema); | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. 
| ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = - | new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length]; - | for (int i = 0; i < aggregateBufferSchema.fields().length; i++) { + | new ${classOf[OnHeapColumnVector].getName}[$aggBufferSchemaFieldsLength]; + | for (int i = 0; i < $aggBufferSchemaFieldsLength; i++) { | aggBufferVectors[i] = vectors[i + ${groupingKeys.length}]; | } | aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 7487564ed64da..1a48bc8398a63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,81 +17,83 @@ package org.apache.spark.sql.execution.arrow -import java.io.ByteArrayOutputStream -import java.nio.channels.Channels +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} +import java.nio.channels.{Channels, SeekableByteChannel} import scala.collection.JavaConverters._ +import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter} -import org.apache.arrow.vector.ipc.message.ArrowRecordBatch -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferOutputStream, Utils} /** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ -private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { +private[sql] class ArrowBatchStreamWriter( + schema: StructType, + out: OutputStream, + timeZoneId: String) { - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(payload, allocator) - } + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + // Write the Arrow schema first, before batches + MessageSerializer.serialize(writeChannel, arrowSchema) /** - * Get the ArrowPayload as a type that can be served to Python. + * Consume iterator to write each serialized ArrowRecordBatch to the stream. 
*/ - def asPythonSerializable: Array[Byte] = payload -} - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { + def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { + arrowBatchIter.foreach(writeChannel.write) + } /** - * Return the schema loaded from the Arrow record batch being iterated over + * End the Arrow stream, does not close output stream. */ - def schema: StructType + def end(): Unit = { + ArrowStreamWriter.writeEndOfStream(writeChannel) + } } private[sql] object ArrowConverters { /** - * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload - * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size + * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ - private[sql] def toPayloadIterator( + private[sql] def toBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, timeZoneId: String, - context: TaskContext): Iterator[ArrowPayload] = { + context: TaskContext): Iterator[Array[Byte]] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = - ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) + val unloader = new VectorUnloader(root) val arrowWriter = ArrowWriter.create(root) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => root.close() allocator.close() } - new Iterator[ArrowPayload] { + new Iterator[Array[Byte]] { override def hasNext: Boolean = rowIter.hasNext || { root.close() @@ -99,9 +101,9 @@ private[sql] object ArrowConverters { false } - override def next(): ArrowPayload = { + override def next(): Array[Byte] = { val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + val writeChannel = new WriteChannel(Channels.newChannel(out)) Utils.tryWithSafeFinally { var rowCount = 0 @@ -111,45 +113,46 @@ private[sql] object ArrowConverters { rowCount += 1 } arrowWriter.finish() - writer.writeBatch() + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + batch.close() } { arrowWriter.reset() - writer.close() } - new ArrowPayload(out.toByteArray) + out.toByteArray } } } /** - * Maps Iterator from ArrowPayload to InternalRow. Returns a pair containing the row iterator - * and the schema from the first batch of Arrow data read. + * Maps iterator from serialized ArrowRecordBatches to InternalRows. 
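Illustrative sketch (not part of the patch) of the per-batch serialization step that `next()` above now performs: unload the populated `VectorSchemaRoot` and write one stream-format RecordBatch message to a byte array, using only the Arrow IPC calls already imported in this file.

```scala
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.ipc.WriteChannel
import org.apache.arrow.vector.ipc.message.MessageSerializer

// Serialize the current contents of a VectorSchemaRoot as one Arrow
// RecordBatch message (length prefix + flatbuffer metadata + body).
def serializeBatch(root: VectorSchemaRoot): Array[Byte] = {
  val out = new ByteArrayOutputStream()
  val writeChannel = new WriteChannel(Channels.newChannel(out))
  val batch = new VectorUnloader(root).getRecordBatch()
  try {
    MessageSerializer.serialize(writeChannel, batch)
  } finally {
    batch.close()
  }
  out.toByteArray
}
```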
*/ - private[sql] def fromPayloadIterator( - payloadIter: Iterator[ArrowPayload], - context: TaskContext): ArrowRowIterator = { + private[sql] def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext): Iterator[InternalRow] = { val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue) + + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val root = VectorSchemaRoot.create(arrowSchema, allocator) - new ArrowRowIterator { - private var reader: ArrowFileReader = null - private var schemaRead = StructType(Seq.empty) - private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty + new Iterator[InternalRow] { + private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty - context.addTaskCompletionListener { _ => - closeReader() + context.addTaskCompletionListener[Unit] { _ => + root.close() allocator.close() } - override def schema: StructType = schemaRead - override def hasNext: Boolean = rowIter.hasNext || { - closeReader() - if (payloadIter.hasNext) { + if (arrowBatchIter.hasNext) { rowIter = nextBatch() true } else { + root.close() allocator.close() false } @@ -157,19 +160,11 @@ private[sql] object ArrowConverters { override def next(): InternalRow = rowIter.next() - private def closeReader(): Unit = { - if (reader != null) { - reader.close() - reader = null - } - } - private def nextBatch(): Iterator[InternalRow] = { - val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable) - reader = new ArrowFileReader(in, allocator) - reader.loadNextBatch() // throws IOException - val root = reader.getVectorSchemaRoot // throws IOException - schemaRead = ArrowUtils.fromArrowSchema(root.getSchema) + val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(arrowRecordBatch) + arrowRecordBatch.close() val columns = root.getFieldVectors.asScala.map { vector => new ArrowColumnVector(vector).asInstanceOf[ColumnVector] @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } + val in = new ByteArrayInputStream(batchBytes) + MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
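The matching read path, as used by `fromBatchIterator`/`loadBatch` above: deserialize one RecordBatch message and load it into a caller-owned `VectorSchemaRoot`. Again a standalone sketch built only from calls that appear in this file.

```scala
import java.io.ByteArrayInputStream
import java.nio.channels.Channels

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.ReadChannel
import org.apache.arrow.vector.ipc.message.MessageSerializer

// Deserialize one Arrow RecordBatch and populate the caller's VectorSchemaRoot.
// The root and allocator are owned by the caller and closed on task completion.
def loadIntoRoot(
    batchBytes: Array[Byte],
    root: VectorSchemaRoot,
    allocator: BufferAllocator): Unit = {
  val in = new ByteArrayInputStream(batchBytes)
  val batch = MessageSerializer.deserializeRecordBatch(
    new ReadChannel(Channels.newChannel(in)), allocator)
  try {
    new VectorLoader(root).load(batch)
  } finally {
    batch.close()
  }
}
```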
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val rdd = payloadRDD.rdd.mapPartitions { iter => + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone + val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) + } + sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema) + } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { + Utils.tryWithResource(new FileInputStream(filename)) { fileStream => + // Create array to consume iterator so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + } + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + + // Iterate over the serialized Arrow RecordBatch messages from a stream + new Iterator[Array[Byte]] { + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { + val prevBatch = batch + batch = readNextBatch() + prevBatch + } + + // This gets the next serialized ArrowRecordBatch by reading message metadata to check if it + // is a RecordBatch message and then returning the complete serialized message which consists + // of a int32 length, serialized message metadata and a serialized RecordBatch message body + def readNextBatch(): Array[Byte] = { + val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) + if (msgMetadata == null) { + return null + } + + // Get the length of the body, which has not been read at this point + val bodyLength = msgMetadata.getMessageBodyLength.toInt + + // Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages + if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { + + // Buffer backed output large enough to hold the complete serialized message + val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) + + // Write message metadata to ByteBuffer output stream + MessageSerializer.writeMessageBuffer( + new WriteChannel(Channels.newChannel(bbout)), + msgMetadata.getMessageLength, + msgMetadata.getMessageBuffer) + + // Get a zero-copy ByteBuffer with already contains message metadata, must close first + bbout.close() + val bb = bbout.toByteBuffer + bb.position(bbout.getCount()) + + // Read message body directly into the ByteBuffer to avoid copy, return backed byte array + bb.limit(bb.capacity()) + JavaUtils.readFully(in, bb) + bb.array() + } else { + if (bodyLength > 0) { + // Skip message body if not a RecordBatch + in.position(in.position() + bodyLength) + } + + // Proceed to next message + readNextBatch() + } + } } - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] - sqlContext.internalCreateDataFrame(rdd, schema) } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 6ad11bda84bf6..533097ac399e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -23,6 +23,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ object ArrowUtils { @@ -46,11 +47,13 @@ object ArrowUtils { case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType => if (timeZoneId == null) { - throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + throw new UnsupportedOperationException( + s"${TimestampType.catalogString} must supply timeZoneId parameter") } else { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } - case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + case _ => + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") } def fromArrowType(dt: ArrowType): DataType = dt match { @@ -120,4 +123,19 @@ object ArrowUtils { StructField(field.getName, dt, field.isNullable) }) } + + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = if (conf.pandasRespectSessionTimeZone) { + Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + } else { + Nil + } + val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) { + Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true") + } else { + Nil + } + Map(timeZoneConf ++ pandasColsByPosition: _*) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 22b63513548fe..8dd484af6e908 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ -import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -62,13 +61,13 @@ object ArrowWriter { case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) - case (StructType(_), vector: NullableMapVector) => + case (StructType(_), vector: StructVector) => val children = (0 until vector.size()).map { ordinal => createFieldWriter(vector.getChildByOrdinal(ordinal)) } new StructWriter(vector, children.toArray) case (dt, _) => - throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") } } } @@ -129,12 +128,7 @@ private[arrow] abstract class ArrowFieldWriter { } def reset(): Unit = { - // TODO: reset() should be in a common interface - valueVector match { - case fixedWidthVector: BaseFixedWidthVector => 
fixedWidthVector.reset() - case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() - case _ => - } + valueVector.reset() count = 0 } } @@ -315,7 +309,7 @@ private[arrow] class ArrayWriter( } private[arrow] class StructWriter( - val valueVector: NullableMapVector, + val valueVector: StructVector, children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..9434ceb7cd16c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -345,6 +345,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) override val output: Seq[Attribute] = range.output + override def outputOrdering: Seq[SortOrder] = range.outputOrdering + + override def outputPartitioning: Partitioning = { + if (numElements > 0) { + if (numSlices == 1) { + SinglePartition + } else { + RangePartitioning(outputOrdering, numSlices) + } + } else { + UnknownPartitioning(0) + } + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -629,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index e9b150fd86095..542a10fc175c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -717,7 +717,7 @@ private[columnar] object ColumnType { case struct: StructType => STRUCT(struct) case udt: UserDefinedType[_] => apply(udt.sqlType) case other => - throw new Exception(s"Unsupported type: ${other.simpleString}") + throw new Exception(s"Unsupported type: ${other.catalogString}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index a7ba9b86a176f..1a8fbaca53f59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -29,20 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.LongAccumulator - - -object InMemoryRelation { - def apply( - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - child: SparkPlan, - tableName: Option[String], - logicalPlan: LogicalPlan): 
InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) -} +import org.apache.spark.util.{LongAccumulator, Utils} /** @@ -55,58 +42,51 @@ object InMemoryRelation { private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) -case class InMemoryRelation( - output: Seq[Attribute], +case class CachedRDDBuilder( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - @transient child: SparkPlan, + @transient cachedPlan: SparkPlan, tableName: Option[String])( - @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics, - override val outputOrdering: Seq[SortOrder]) - extends logical.LeafNode with MultiInstanceRelation { - - override protected def innerChildren: Seq[SparkPlan] = Seq(child) + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) { - override def doCanonicalize(): logical.LogicalPlan = - copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), - storageLevel = StorageLevel.NONE, - child = child.canonicalized, - tableName = None)( - _cachedColumnBuffers, - sizeInBytesStats, - statsOfPlanToCache, - outputOrdering) - - override def producedAttributes: AttributeSet = outputSet + val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator - @transient val partitionStatistics = new PartitionStatistics(output) + def cachedColumnBuffers: RDD[CachedBatch] = { + if (_cachedColumnBuffers == null) { + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() + } + } + } + _cachedColumnBuffers + } - override def computeStats(): Statistics = { - if (sizeInBytesStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. - // Note that we should drop the hint info here. We may cache a plan whose root node is a hint - // node. When we lookup the cache with a semantically same plan without hint info, the plan - // returned by cache lookup should not have hint info. If we lookup the cache with a - // semantically same plan with a different hint info, `CacheManager.useCachedData` will take - // care of it and retain the hint info in the lookup input plan. - statsOfPlanToCache.copy(hints = HintInfo()) - } else { - Statistics(sizeInBytes = sizeInBytesStats.value.longValue) + def clearCache(blocking: Boolean = true): Unit = { + if (_cachedColumnBuffers != null) { + synchronized { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } + } } } - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. 
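`cachedColumnBuffers`/`clearCache` above follow a check-lock-check, build-at-most-once pattern so the buffers RDD is materialized lazily and can be dropped and rebuilt. A generic sketch of that pattern (class name hypothetical; `@volatile` is added here for safe publication in this standalone form):

```scala
// Minimal sketch of a lazily-built, clearable holder, as used for the CachedBatch RDD.
class LazyClearable[T <: AnyRef](build: () => T, release: T => Unit) {
  @volatile private var value: T = _

  def get: T = {
    if (value == null) {
      synchronized {
        if (value == null) {
          value = build()
        }
      }
    }
    value
  }

  def clear(): Unit = synchronized {
    if (value != null) {
      release(value)
      value = null.asInstanceOf[T]
    }
  }
}
```

In the patch, `build()` corresponds to `buildBuffers()` and `release` to unpersisting the cached RDD.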
- if (_cachedColumnBuffers == null) { - buildBuffers() + def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = { + new CachedRDDBuilder( + useCompression, + batchSize, + storageLevel, + cachedPlan = cachedPlan, + tableName + )(_cachedColumnBuffers) } - private def buildBuffers(): Unit = { - val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => + private def buildBuffers(): RDD[CachedBatch] = { + val output = cachedPlan.output + val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -154,32 +134,80 @@ case class InMemoryRelation( cached.setName( tableName.map(n => s"In-memory table $n") - .getOrElse(StringUtils.abbreviate(child.toString, 1024))) - _cachedColumnBuffers = cached + .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))) + cached + } +} + +object InMemoryRelation { + + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + logicalPlan: LogicalPlan): InMemoryRelation = { + val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } + + def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { + new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } +} + +case class InMemoryRelation( + output: Seq[Attribute], + @transient cacheBuilder: CachedRDDBuilder)( + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) + extends logical.LeafNode with MultiInstanceRelation { + + override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) + + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)), + cacheBuilder)( + statsOfPlanToCache, + outputOrdering) + + override def producedAttributes: AttributeSet = outputSet + + @transient val partitionStatistics = new PartitionStatistics(output) + + def cachedPlan: SparkPlan = cacheBuilder.cachedPlan + + override def computeStats(): Statistics = { + if (cacheBuilder.sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. 
+ statsOfPlanToCache.copy(hints = HintInfo()) + } else { + Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue) + } } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { - InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) + InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), - useCompression, - batchSize, - storageLevel, - child, - tableName)( - _cachedColumnBuffers, - sizeInBytesStats, + cacheBuilder)( statsOfPlanToCache, outputOrdering).asInstanceOf[this.type] } - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers + override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) - override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + override def simpleString: String = + s"InMemoryRelation [${Utils.truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e73e1378d52e3..196d057c2de1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ @@ -78,10 +78,12 @@ case class InMemoryTableScanExec( private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) - private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + private def createAndDecompressColumn( + cachedColumnarBatch: CachedBatch, + offHeapColumnVectorEnabled: Boolean): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows val taskContext = Option(TaskContext.get()) - val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) { OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) @@ -95,16 +97,19 @@ case class InMemoryTableScanExec( columnarBatch.column(i).asInstanceOf[WritableColumnVector], columnarBatchSchema.fields(i).dataType, rowCount) } - taskContext.foreach(_.addTaskCompletionListener(_ => columnarBatch.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => columnarBatch.close())) columnarBatch } private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() + val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled if (supportsBatch) { // HACK ALERT: This is actually an RDD[ColumnarBatch]. 
// We're taking advantage of Scala's type erasure here to pass these batches along. - buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + buffers + .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) + .asInstanceOf[RDD[InternalRow]] } else { val numOutputRows = longMetric("numOutputRows") @@ -154,7 +159,7 @@ case class InMemoryTableScanExec( private def updateAttribute(expr: Expression): Expression = { // attributes can be pruned so using relation's output. // E.g., relation.output is [id, item] but this scan's output can be [item] only. - val attrMap = AttributeMap(relation.child.output.zip(relation.output)) + val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } @@ -163,21 +168,33 @@ case class InMemoryTableScanExec( // The cached version does not change the outputPartitioning of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { - relation.child.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.child.outputPartitioning + relation.cachedPlan.outputPartitioning match { + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other } } // The cached version does not change the outputOrdering of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputOrdering: Seq[SortOrder] = - relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) + relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) // Keeps relation's partition statistics because we don't serialize relation. private val stats = relation.partitionStatistics private def statsFor(a: Attribute) = stats.forAttribute(a) + // Currently, only use statistics from atomic types except binary type only. + private object ExtractableLiteral { + def unapply(expr: Expression): Option[Literal] = expr match { + case lit: Literal => lit.dataType match { + case BinaryType => None + case _: AtomicType => Some(lit) + case _ => None + } + case _ => None + } + } + // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. 
@transient lazy val buildFilter: PartialFunction[Expression, Expression] = { @@ -189,33 +206,37 @@ case class InMemoryTableScanExec( if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => buildFilter(lhs) || buildFilter(rhs) - case EqualTo(a: AttributeReference, l: Literal) => + case EqualTo(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualTo(l: Literal, a: AttributeReference) => + case EqualTo(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualNullSafe(a: AttributeReference, l: Literal) => + case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualNullSafe(l: Literal, a: AttributeReference) => + case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l - case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound + case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l + case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound - case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l - case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound + case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) => + statsFor(a).lowerBound <= l + case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) => + l <= statsFor(a).upperBound - case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound - case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l + case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound + case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l - case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound - case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l + case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) => + l <= statsFor(a).upperBound + case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) => + statsFor(a).lowerBound <= l case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 case In(a: AttributeReference, list: Seq[Expression]) - if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => + if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } @@ -252,7 +273,7 @@ case class InMemoryTableScanExec( // within the map Partitions closure. 
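Each `buildFilter` case above reduces a row-level comparison against an `ExtractableLiteral` to a batch-level test on the partition statistics' lower/upper bounds. A plain-value sketch of that pruning rule, outside Catalyst:

```scala
// Plain-value sketch of the min/max pruning encoded by buildFilter above.
// Returns false only when no row in the batch could satisfy the predicate.
def batchMayMatch[T](lower: T, upper: T, op: String, lit: T)(
    implicit ord: Ordering[T]): Boolean = {
  import ord._
  op match {
    case "="  => lower <= lit && lit <= upper
    case "<"  => lower < lit
    case "<=" => lower <= lit
    case ">"  => lit < upper
    case ">=" => lit <= upper
    case _    => true // unknown predicate: never prune
  }
}

// A batch with stats lower = 10, upper = 20 cannot contain rows with col > 25:
// batchMayMatch(10, 20, ">", 25) == false, so the whole batch is skipped.
```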
val schema = stats.schema val schemaIndex = schema.zipWithIndex - val buffers = relation.cachedColumnBuffers + val buffers = relation.cacheBuilder.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 640e01336aa75..3fea6d7c7fbfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -47,7 +47,7 @@ case class AnalyzeColumnCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val sizeInBytes = CommandUtils.calculateTotalSize(sessionState, tableMeta) + val sizeInBytes = CommandUtils.calculateTotalSize(sparkSession, tableMeta) // Compute stats for each column val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index 5b54b2270b5ec..18fefa0a6f19f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Column, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} import org.apache.spark.sql.execution.datasources.PartitioningUtils @@ -140,7 +140,13 @@ case class AnalyzePartitionCommand( val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() df.collect().map { r => - val partitionColumnValues = partitionColumns.indices.map(r.get(_).toString) + val partitionColumnValues = partitionColumns.indices.map { i => + if (r.isNullAt(i)) { + ExternalCatalogUtils.DEFAULT_PARTITION_NAME + } else { + r.get(i).toString + } + } val spec = tableMeta.partitionColumnNames.zip(partitionColumnValues).toMap val count = BigInt(r.getLong(partitionColumns.size)) (spec, count) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 58b53e8b1c551..3076e919dd61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -39,7 +39,7 @@ case class AnalyzeTableCommand( } // Compute stats for the whole table - val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) + val newTotalSize = CommandUtils.calculateTotalSize(sparkSession, tableMeta) val newRowCount = if (noscan) None else 
Some(BigInt(sparkSession.table(tableIdentWithDB).count())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index c27048626c8eb..df71bc9effb3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -21,12 +21,13 @@ import java.net.URI import scala.util.control.NonFatal -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition} +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex} import org.apache.spark.sql.internal.SessionState @@ -38,7 +39,7 @@ object CommandUtils extends Logging { val catalog = sparkSession.sessionState.catalog if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) { val newTable = catalog.getTableMetadata(table.identifier) - val newSize = CommandUtils.calculateTotalSize(sparkSession.sessionState, newTable) + val newSize = CommandUtils.calculateTotalSize(sparkSession, newTable) val newStats = CatalogStatistics(sizeInBytes = newSize) catalog.alterTableStats(table.identifier, Some(newStats)) } else { @@ -47,15 +48,29 @@ object CommandUtils extends Logging { } } - def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): BigInt = { + def calculateTotalSize(spark: SparkSession, catalogTable: CatalogTable): BigInt = { + val sessionState = spark.sessionState if (catalogTable.partitionColumnNames.isEmpty) { calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) } else { // Calculate table size as a sum of the visible partitions. 
See SPARK-21079 val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) - partitions.map { p => - calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) - }.sum + if (spark.sessionState.conf.parallelFileListingInStatsComputation) { + val paths = partitions.map(x => new Path(x.storage.locationUri.get)) + val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + val pathFilter = new PathFilter with Serializable { + override def accept(path: Path): Boolean = { + DataSourceUtils.isDataPath(path) && !path.getName.startsWith(stagingDir) + } + } + val fileStatusSeq = InMemoryFileIndex.bulkListLeafFiles( + paths, sessionState.newHadoopConf(), pathFilter, spark) + fileStatusSeq.flatMap(_._2.map(_.getLen)).sum + } else { + partitions.map { p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + }.sum + } } } @@ -78,7 +93,8 @@ object CommandUtils extends Logging { val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { + if (!status.getPath.getName.startsWith(stagingDir) && + DataSourceUtils.isDataPath(path)) { getPathSize(fs, status.getPath) } else { 0L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bf4d96fa18d0d..e1faecedd20ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -189,8 +189,9 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val isTempView = catalog.isTemporaryTable(tableName) - if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) { + if (!isTempView && catalog.tableExists(tableName)) { // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view // issue an exception. 
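Referring back to the parallel size computation earlier in this hunk: the serializable `PathFilter` combines two exclusions, Hive staging directories and metadata/temporary files (the `isDataPath` check added to `DataSourceUtils` later in this patch). A standalone sketch of that predicate (helper name hypothetical):

```scala
import org.apache.hadoop.fs.{Path, PathFilter}

// Filter used when listing partition files for size statistics: skip Hive
// staging dirs plus metadata/temporary files (names starting with "_" or ".").
def statsPathFilter(stagingDir: String): PathFilter = new PathFilter with Serializable {
  override def accept(path: Path): Boolean = {
    val name = path.getName
    !name.startsWith(stagingDir) && !(name.startsWith("_") || name.startsWith("."))
  }
}
```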
catalog.getTableMetadata(tableName).tableType match { @@ -204,9 +205,10 @@ case class DropTableCommand( } } - if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) { + if (isTempView || catalog.tableExists(tableName)) { try { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName), cascade = !isTempView) } catch { case NonFatal(e) => log.warn(e.toString, e) } @@ -890,7 +892,8 @@ object DDLUtils { */ def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = { val inputPaths = query.collect { - case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths + case LogicalRelation(r: HadoopFsRelation, _, _, _) => + r.location.rootPaths }.flatten if (inputPaths.contains(outputPath)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 44749190c79eb..2eca1c40a5b3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.command import java.io.File -import java.net.URI +import java.net.{URI, URISyntaxException} import java.nio.file.FileSystems import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.NonFatal -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileContext, FsConstants, Path} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.Histogram -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -303,94 +303,44 @@ case class LoadDataCommand( s"partitioned, but a partition spec was provided.") } } - - val loadPath = + val loadPath = { if (isLocal) { - val uri = Utils.resolveURI(path) - val file = new File(uri.getPath) - val exists = if (file.getAbsolutePath.contains("*")) { - val fileSystem = FileSystems.getDefault - val dir = file.getParentFile.getAbsolutePath - if (dir.contains("*")) { - throw new AnalysisException( - s"LOAD DATA input path allows only filename wildcard: $path") - } - - // Note that special characters such as "*" on Windows are not allowed as a path. - // Calling `WindowsFileSystem.getPath` throws an exception if there are in the path. - val dirPath = fileSystem.getPath(dir) - val pathPattern = new File(dirPath.toAbsolutePath.toString, file.getName).toURI.getPath - val safePathPattern = if (Utils.isWindows) { - // On Windows, the pattern should not start with slashes for absolute file paths. 
- pathPattern.stripPrefix("/") - } else { - pathPattern - } - val files = new File(dir).listFiles() - if (files == null) { - false - } else { - val matcher = fileSystem.getPathMatcher("glob:" + safePathPattern) - files.exists(f => matcher.matches(fileSystem.getPath(f.getAbsolutePath))) - } - } else { - new File(file.getAbsolutePath).exists() - } - if (!exists) { - throw new AnalysisException(s"LOAD DATA input path does not exist: $path") - } - uri + val localFS = FileContext.getLocalFSFileContext() + makeQualified(FsConstants.LOCAL_FS_URI, localFS.getWorkingDirectory(), new Path(path)) } else { - val uri = new URI(path) - val hdfsUri = if (uri.getScheme() != null && uri.getAuthority() != null) { - uri - } else { - // Follow Hive's behavior: - // If no schema or authority is provided with non-local inpath, - // we will use hadoop configuration "fs.defaultFS". - val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.defaultFS") - val defaultFS = if (defaultFSConf == null) { - new URI("") - } else { - new URI(defaultFSConf) - } - - val scheme = if (uri.getScheme() != null) { - uri.getScheme() - } else { - defaultFS.getScheme() - } - val authority = if (uri.getAuthority() != null) { - uri.getAuthority() - } else { - defaultFS.getAuthority() - } - - if (scheme == null) { - throw new AnalysisException( - s"LOAD DATA: URI scheme is required for non-local input paths: '$path'") - } - - // Follow Hive's behavior: - // If LOCAL is not specified, and the path is relative, - // then the path is interpreted relative to "/user/" - val uriPath = uri.getPath() - val absolutePath = if (uriPath != null && uriPath.startsWith("/")) { - uriPath - } else { - s"/user/${System.getProperty("user.name")}/$uriPath" - } - new URI(scheme, authority, absolutePath, uri.getQuery(), uri.getFragment()) - } - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val srcPath = new Path(hdfsUri) - val fs = srcPath.getFileSystem(hadoopConf) - if (!fs.exists(srcPath)) { - throw new AnalysisException(s"LOAD DATA input path does not exist: $path") - } - hdfsUri + val loadPath = new Path(path) + // Follow Hive's behavior: + // If no schema or authority is provided with non-local inpath, + // we will use hadoop configuration "fs.defaultFS". + val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.defaultFS") + val defaultFS = if (defaultFSConf == null) new URI("") else new URI(defaultFSConf) + // Follow Hive's behavior: + // If LOCAL is not specified, and the path is relative, + // then the path is interpreted relative to "/user/" + val uriPath = new Path(s"/user/${System.getProperty("user.name")}/") + // makeQualified() will ignore the query parameter part while creating a path, so the + // entire string will be considered while making a Path instance,this is mainly done + // by considering the wild card scenario in mind.as per old logic query param is + // been considered while creating URI instance and if path contains wild card char '?' + // the remaining charecters after '?' will be removed while forming URI instance + makeQualified(defaultFS, uriPath, loadPath) } - + } + val fs = loadPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + // This handling is because while resolving the invalid URLs starting with file:/// + // system throws IllegalArgumentException from globStatus API,so in order to handle + // such scenarios this code is added in try catch block and after catching the + // runtime exception a generic error will be displayed to the user. 
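The comment block above is the behavioural crux of this rewrite: building a Hadoop `Path` keeps a `?` glob wildcard intact, whereas constructing a `java.net.URI` from the raw string treats everything after `?` as a query. A small REPL-style illustration (path value hypothetical):

```scala
import java.net.URI
import org.apache.hadoop.fs.Path

val raw = "/user/hive/warehouse/t/part-000?.csv"

new Path(raw).toUri.getPath // "/user/hive/warehouse/t/part-000?.csv" -- wildcard preserved
new URI(raw).getPath        // "/user/hive/warehouse/t/part-000"      -- truncated at '?'
new URI(raw).getQuery       // ".csv"
```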
+ try { + val fileStatus = fs.globStatus(loadPath) + if (fileStatus == null || fileStatus.isEmpty) { + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") + } + } catch { + case e: IllegalArgumentException => + log.warn(s"Exception while validating the load path $path ", e) + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") + } if (partition.nonEmpty) { catalog.loadPartition( targetTable.identifier, @@ -413,6 +363,36 @@ case class LoadDataCommand( CommandUtils.updateTableStats(sparkSession, targetTable) Seq.empty[Row] } + + /** + * Returns a qualified path object. Method ported from org.apache.hadoop.fs.Path class. + * + * @param defaultUri default uri corresponding to the filesystem provided. + * @param workingDir the working directory for the particular child path wd-relative names. + * @param path Path instance based on the path string specified by the user. + * @return qualified path object + */ + private def makeQualified(defaultUri: URI, workingDir: Path, path: Path): Path = { + val pathUri = if (path.isAbsolute()) path.toUri() else new Path(workingDir, path).toUri() + if (pathUri.getScheme == null || pathUri.getAuthority == null && + defaultUri.getAuthority != null) { + val scheme = if (pathUri.getScheme == null) defaultUri.getScheme else pathUri.getScheme + val authority = if (pathUri.getAuthority == null) { + if (defaultUri.getAuthority == null) "" else defaultUri.getAuthority + } else { + pathUri.getAuthority + } + try { + val newUri = new URI(scheme, authority, pathUri.getPath, pathUri.getFragment) + new Path(newUri) + } catch { + case e: URISyntaxException => + throw new IllegalArgumentException(e) + } + } else { + path + } + } } /** @@ -493,7 +473,7 @@ case class TruncateTableCommand( spark.sessionState.refreshTable(tableName.unquotedString) // Also try to drop the contents of the table from the columnar cache try { - spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier)) + spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), cascade = true) } catch { case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) @@ -960,6 +940,9 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman case EXTERNAL => " EXTERNAL TABLE" case VIEW => " VIEW" case MANAGED => " TABLE" + case t => + throw new IllegalArgumentException( + s"Unknown table type is found at showCreateHiveTable: $t") } builder ++= s"CREATE$tableTypeString ${table.quotedString}" @@ -982,7 +965,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showHiveTableHeader(metadata: CatalogTable, builder: StringBuilder): Unit = { val columns = metadata.schema.filterNot { column => metadata.partitionColumnNames.contains(column.name) - }.map(columnToDDLFragment) + }.map(_.toDDL) if (columns.nonEmpty) { builder ++= columns.mkString("(", ", ", ")\n") @@ -994,14 +977,10 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman .foreach(builder.append) } - private def columnToDDLFragment(column: StructField): String = { - val comment = column.getComment().map(escapeSingleQuotedString).map(" COMMENT '" + _ + "'") - s"${quoteIdentifier(column.name)} ${column.dataType.catalogString}${comment.getOrElse("")}" - } private def showHiveTableNonDataColumns(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.partitionColumnNames.nonEmpty) { - val partCols = metadata.partitionSchema.map(columnToDDLFragment) 
+ val partCols = metadata.partitionSchema.map(_.toDDL) builder ++= partCols.mkString("PARTITIONED BY (", ", ", ")\n") } @@ -1072,7 +1051,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showDataSourceTableDataColumns( metadata: CatalogTable, builder: StringBuilder): Unit = { - val columns = metadata.schema.fields.map(f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}") + val columns = metadata.schema.fields.map(_.toDDL) builder ++= columns.mkString("(", ", ", ")\n") } @@ -1117,15 +1096,4 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman } } } - - private def escapeSingleQuotedString(str: String): String = { - val builder = StringBuilder.newBuilder - - str.foreach { - case '\'' => builder ++= s"\\\'" - case ch => builder += ch - } - - builder.toString() - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 69c03d862391e..ba7d2b7cbdb1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration /** - * Simple metrics collected during an instance of [[FileFormatWriter.ExecuteWriteTask]]. + * Simple metrics collected during an instance of [[FileFormatDataWriter]]. * These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703). */ case class BasicWriteTaskStats( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala index ea4fe9c8ade5f..a776fc3e7021d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning + object BucketingUtils { // The file name of bucketed data should have 3 parts: // 1. 
some other information in the head of file name @@ -35,5 +38,16 @@ object BucketingUtils { case other => None } + // Given bucketColumn, numBuckets and value, returns the corresponding bucketId + def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) + mutableInternalRow.update(0, value) + + val bucketIdGenerator = UnsafeProjection.create( + HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil, + bucketColumn :: Nil) + bucketIdGenerator(mutableInternalRow).getInt(0) + } + def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index c0df6c779d7bd..9fddfad249e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -50,7 +50,7 @@ object CodecStreams { */ def createInputStreamWithCloseResource(config: Configuration, path: Path): InputStream = { val inputStream = createInputStream(config, path) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => inputStream.close())) inputStream } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e77..1dcf9f3185de9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -396,6 +396,7 @@ case class DataSource( hs.partitionSchema.map(_.name), "in the partition schema", equality) + DataSourceUtils.verifyReadSchema(hs.fileFormat, hs.dataSchema) case _ => SchemaUtils.checkColumnNameDuplication( relation.schema.map(_.name), @@ -613,6 +614,8 @@ object DataSource extends Logging { case name if name.equalsIgnoreCase("orc") && conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" => "org.apache.spark.sql.hive.orc.OrcFileFormat" + case "com.databricks.spark.avro" if conf.replaceDatabricksSparkAvroEnabled => + "org.apache.spark.sql.avro.AvroFileFormat" case name => name } val provider2 = s"$provider1.DefaultSource" @@ -635,11 +638,17 @@ object DataSource extends Logging { "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " + "'native'") } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || - provider1 == "com.databricks.spark.avro") { + provider1 == "com.databricks.spark.avro" || + provider1 == "org.apache.spark.sql.avro") { throw new AnalysisException( - s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " + - "Please find an Avro package at " + - "http://spark.apache.org/third-party-projects.html") + s"Failed to find data source: $provider1. Avro is built-in but external data " + + "source module since Spark 2.4. Please deploy the application as per " + + "the deployment section of \"Apache Avro Data Source Guide\".") + } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") { + throw new AnalysisException( + s"Failed to find data source: $provider1. 
Please deploy the application as " + + "per the deployment section of " + + "\"Structured Streaming + Kafka Integration Guide\".") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. Please find packages at " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 3f41612c08065..6b61e749e3063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -131,7 +131,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast projectList } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -252,7 +252,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] table.partitionSchema.asNullable.toAttributes) } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) => i.copy(table = readDataSourceTable(tableMeta)) @@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with case _ => Nil } - // Get the bucket ID based on the bucketing values. - // Restriction: Bucket pruning works iff the bucketing column has one and only one column. - def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { - val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) - val bucketIdGeneration = UnsafeProjection.create( - HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, - bucketColumn :: Nil) - - bucketIdGeneration(mutableRow).getInt(0) - } - // Based on Public API. private def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala new file mode 100644 index 0000000000000..90cec5e72c1a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types._ + + +object DataSourceUtils { + + /** + * Verify if the schema is supported in datasource in write path. + */ + def verifyWriteSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = false) + } + + /** + * Verify if the schema is supported in datasource in read path. + */ + def verifyReadSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = true) + } + + /** + * Verify if the schema is supported in datasource. This verification should be done + * in a driver side. + */ + private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = { + schema.foreach { field => + if (!format.supportDataType(field.dataType, isReadPath)) { + throw new AnalysisException( + s"$format data source does not support ${field.dataType.catalogString} data type.") + } + } + } + + // SPARK-24626: Metadata files and temporary files should not be + // counted as data files, so that they shouldn't participate in tasks like + // location size calculation. + private[sql] def isDataPath(path: Path): Boolean = { + val name = path.getName + !(name.startsWith("_") || name.startsWith(".")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala index 43591a9ff524a..90e81661bae7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String @@ -28,7 +29,8 @@ class FailureSafeParser[IN]( rawParser: IN => Seq[InternalRow], mode: ParseMode, schema: StructType, - columnNameOfCorruptRecord: String) { + columnNameOfCorruptRecord: String, + isMultiLine: Boolean) { private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) @@ -56,9 +58,15 @@ class FailureSafeParser[IN]( } } + private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty + def parse(input: IN): Iterator[InternalRow] = { try { - rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + if (skipParsing) { + Iterator.single(InternalRow.empty) + } else { + rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + } } catch { case e: BadRecordException => mode match { case PermissiveMode => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 023e127888290..2c162e23644ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ 
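The `skipParsing` flag added to `FailureSafeParser` above is a fast path for queries that need no columns at all (for example `count()` in permissive mode over a non-multiline source): every record can be turned into an empty row without tokenizing it. A minimal sketch of that idea, with a hypothetical `records` iterator standing in for the raw parser input:

```scala
import org.apache.spark.sql.catalyst.InternalRow

// Fast path sketch: when the required schema is empty, emit one empty row per record
// instead of running the CSV/JSON tokenizer.
def countOnlyParse(records: Iterator[String]): Iterator[InternalRow] =
  records.map(_ => InternalRow.empty)
```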
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** @@ -57,7 +57,7 @@ trait FileFormat { dataSchema: StructType): OutputWriterFactory /** - * Returns whether this format support returning columnar batch or not. + * Returns whether this format supports returning columnar batch or not. * * TODO: we should just have different traits for the different formats. */ @@ -152,6 +152,11 @@ trait FileFormat { } } + /** + * Returns whether this format supports the given [[DataType]] in read/write path. + * By default all data types are supported. + */ + def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala new file mode 100644 index 0000000000000..6499328e89ce7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.SerializableConfiguration + +/** + * Abstract class for writing out data in a single Spark task. + * Exceptions thrown by the implementation of this trait will automatically trigger task aborts. + */ +abstract class FileFormatDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) { + /** + * Max number of files a single task writes out due to file size. In most cases the number of + * files written should be very small. This is just a safe guard to protect some really bad + * settings, e.g. maxRecordsPerFile = 1. 
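The new `supportDataType` hook on `FileFormat`, combined with `DataSourceUtils.verifyReadSchema`/`verifyWriteSchema` above, lets each format reject unsupported column types on the driver before any task runs. The self-contained sketch below mirrors that shape without extending Spark's trait; `SupportsDataTypes`, `FlatFormatSupport` and the whitelist of types are invented for illustration.

```scala
import org.apache.spark.sql.types._

// A flat, CSV-like format that only accepts a whitelist of atomic types.
trait SupportsDataTypes {
  def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true
}

object FlatFormatSupport extends SupportsDataTypes {
  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
    case StringType | IntegerType | LongType | DoubleType | BooleanType | DateType | TimestampType =>
      true
    case _ => false // nested types such as ArrayType, MapType and StructType are rejected
  }
}

// Driver-side check, mirroring DataSourceUtils.verifySchema: fail on the first unsupported field.
def verifyFlatSchema(schema: StructType, isReadPath: Boolean): Unit =
  schema.foreach { field =>
    require(FlatFormatSupport.supportDataType(field.dataType, isReadPath),
      s"data source does not support ${field.dataType.simpleString} data type.")
  }
```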
+ */ + protected val MAX_FILE_COUNTER: Int = 1000 * 1000 + protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() + protected var currentWriter: OutputWriter = _ + + /** Trackers for computing various statistics on the data as it's being written out. */ + protected val statsTrackers: Seq[WriteTaskStatsTracker] = + description.statsTrackers.map(_.newTaskInstance()) + + protected def releaseResources(): Unit = { + if (currentWriter != null) { + try { + currentWriter.close() + } finally { + currentWriter = null + } + } + } + + /** Writes a record */ + def write(record: InternalRow): Unit + + /** + * Returns the summary of relative information which + * includes the list of partition strings written out. The list of partitions is sent back + * to the driver and used to update the catalog. Other information will be sent back to the + * driver too and used to e.g. update the metrics in UI. + */ + def commit(): WriteTaskResult = { + releaseResources() + val summary = ExecutedWriteSummary( + updatedPartitions = updatedPartitions.toSet, + stats = statsTrackers.map(_.getFinalStats())) + WriteTaskResult(committer.commitTask(taskAttemptContext), summary) + } + + def abort(): Unit = { + try { + releaseResources() + } finally { + committer.abortTask(taskAttemptContext) + } + } +} + +/** FileFormatWriteTask for empty partitions */ +class EmptyDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol +) extends FileFormatDataWriter(description, taskAttemptContext, committer) { + override def write(record: InternalRow): Unit = {} +} + +/** Writes data to a single directory (used for non-dynamic-partition writes). */ +class SingleDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + // Initialize currentWriter and statsTrackers + newOutputWriter() + + private def newOutputWriter(): Unit = { + recordsInFile = 0 + releaseResources() + + val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) + val currentPath = committer.newTaskTempFile( + taskAttemptContext, + None, + f"-c$fileCounter%03d" + ext) + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter() + } + + currentWriter.write(record) + statsTrackers.foreach(_.newRow(record)) + recordsInFile += 1 + } +} + +/** + * Writes data to using dynamic partition writes, meaning this single function can write to + * multiple directories (partitions) or files (bucketing). + */ +class DynamicPartitionDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + + /** Flag saying whether or not the data to be written out is partitioned. 
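`SingleDirectoryDataWriter` above rolls to a new output file whenever `maxRecordsPerFile` is exceeded, bounded by `MAX_FILE_COUNTER`. The toy writer below reproduces only that rolling logic over plain text files; the `RollingWriter` name, the file naming and the direct `PrintWriter` IO are simplifications, since the real writer goes through `OutputWriterFactory` and the commit protocol.

```scala
import java.io.{File, PrintWriter}

class RollingWriter(dir: File, maxRecordsPerFile: Long) {
  private val MaxFileCounter = 1000 * 1000 // same safety bound as MAX_FILE_COUNTER above
  private var fileCounter = 0
  private var recordsInFile = 0L
  private var out: PrintWriter = newFile()

  private def newFile(): PrintWriter = {
    recordsInFile = 0
    new PrintWriter(new File(dir, f"part-c$fileCounter%03d.txt"))
  }

  def write(record: String): Unit = {
    if (maxRecordsPerFile > 0 && recordsInFile >= maxRecordsPerFile) {
      out.close()
      fileCounter += 1
      assert(fileCounter < MaxFileCounter, s"File counter $fileCounter is beyond max value")
      out = newFile()
    }
    out.println(record)
    recordsInFile += 1
  }

  def commit(): Unit = out.close()
}
```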
*/ + private val isPartitioned = description.partitionColumns.nonEmpty + + /** Flag saying whether or not the data to be written out is bucketed. */ + private val isBucketed = description.bucketIdExpression.isDefined + + assert(isPartitioned || isBucketed, + s"""DynamicPartitionWriteTask should be used for writing out data that's either + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description + """.stripMargin) + + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + private var currentPartionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + + /** Extracts the partition values out of an input row. */ + private lazy val getPartitionValues: InternalRow => UnsafeRow = { + val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) + row => proj(row) + } + + /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ + private lazy val partitionPathExpression: Expression = Concat( + description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, + StringType, + Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) + }) + + /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. */ + private lazy val getPartitionPath: InternalRow => String = { + val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) + row => proj(row).getString(0) + } + + /** Given an input row, returns the corresponding `bucketId` */ + private lazy val getBucketId: InternalRow => Int = { + val proj = + UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) + row => proj(row).getInt(0) + } + + /** Returns the data columns to be written given an input row */ + private val getOutputRow = + UnsafeProjection.create(description.dataColumns, description.allColumns) + + /** + * Opens a new OutputWriter given a partition key and/or a bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { + recordsInFile = 0 + releaseResources() + + val partDir = partitionValues.map(getPartitionPath(_)) + partDir.foreach(updatedPartitions.add) + + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + + // This must be in a form that matches our bucketing format. See BucketingUtils. 
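`partitionPathExpression` above assembles the Hive-style `col1=val/col2=val` directory fragment from the partition columns. A much-simplified Scala rendering of that shape (the real code escapes special characters via `ExternalCatalogUtils.getPartitionPathString`, applies the session time zone when casting, and evaluates this as a Catalyst expression):

```scala
// Simplified: build a partition directory fragment from (column name, value) pairs.
def partitionPath(values: Seq[(String, Any)]): String =
  values.map { case (name, v) =>
    s"$name=${Option(v).getOrElse("__HIVE_DEFAULT_PARTITION__")}" // null -> Hive default partition
  }.mkString("/")

// partitionPath(Seq("year" -> 2018, "month" -> 7)) == "year=2018/month=7"
```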
+ val ext = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + } + val currentPath = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (isPartitioned && currentPartionValues != nextPartitionValues) { + currentPartionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + statsTrackers.foreach(_.newBucket(currentBucketId.get)) + } + + fileCounter = 0 + newOutputWriter(currentPartionValues, currentBucketId) + } else if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter(currentPartionValues, currentBucketId) + } + val outputRow = getOutputRow(record) + currentWriter.write(outputRow) + statsTrackers.foreach(_.newRow(outputRow)) + recordsInFile += 1 + } +} + +/** A shared job description for all the write tasks. */ +class WriteJobDescription( + val uuid: String, // prevent collision between different (appending) write jobs + val serializableHadoopConf: SerializableConfiguration, + val outputWriterFactory: OutputWriterFactory, + val allColumns: Seq[Attribute], + val dataColumns: Seq[Attribute], + val partitionColumns: Seq[Attribute], + val bucketIdExpression: Option[Expression], + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String], + val maxRecordsPerFile: Long, + val timeZoneId: String, + val statsTrackers: Seq[WriteJobStatsTracker]) + extends Serializable { + + assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), + s""" + |All columns: ${allColumns.mkString(", ")} + |Partition columns: ${partitionColumns.mkString(", ")} + |Data columns: ${dataColumns.mkString(", ")} + """.stripMargin) +} + +/** The result of a successful write task. */ +case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + +/** + * Wrapper class for the metrics of writing data out. + * + * @param updatedPartitions the partitions updated during writing data out. Only valid + * for dynamic partition. + * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. 
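The `write` method of `DynamicPartitionDataWriter` above relies on its input already being clustered by partition and bucket key: a change in the key is the signal to close the current file and open a new one. A generic sketch of that key-change loop, with invented type parameters standing in for partition values and rows:

```scala
// Write rows that arrive grouped by key; open a new sink whenever the key changes.
def writeSorted[K, R](rows: Iterator[(K, R)])(openSinkFor: K => R => Unit): Unit = {
  var currentKey: Option[K] = None
  var sink: R => Unit = _ => ()
  rows.foreach { case (key, row) =>
    if (!currentKey.contains(key)) { // new partition/bucket -> new output file
      currentKey = Some(key)
      sink = openSinkFor(key)
    }
    sink(row)
  }
}
```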
+ */ +case class ExecutedWriteSummary( + updatedPartitions: Set[String], + stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 401597f967218..7c6ab4bc922fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.mutable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -30,62 +28,25 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} -import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { - - /** - * Max number of files a single task writes out due to file size. In most cases the number of - * files written should be very small. This is just a safe guard to protect some really bad - * settings, e.g. maxRecordsPerFile = 1. - */ - private val MAX_FILE_COUNTER = 1000 * 1000 - /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( - outputPath: String, - customPartitionLocations: Map[TablePartitionSpec, String], - outputColumns: Seq[Attribute]) - - /** A shared job description for all the write tasks. */ - private class WriteJobDescription( - val uuid: String, // prevent collision between different (appending) write jobs - val serializableHadoopConf: SerializableConfiguration, - val outputWriterFactory: OutputWriterFactory, - val allColumns: Seq[Attribute], - val dataColumns: Seq[Attribute], - val partitionColumns: Seq[Attribute], - val bucketIdExpression: Option[Expression], - val path: String, - val customPartitionLocations: Map[TablePartitionSpec, String], - val maxRecordsPerFile: Long, - val timeZoneId: String, - val statsTrackers: Seq[WriteJobStatsTracker]) - extends Serializable { - - assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), - s""" - |All columns: ${allColumns.mkString(", ")} - |Partition columns: ${partitionColumns.mkString(", ")} - |Data columns: ${dataColumns.mkString(", ")} - """.stripMargin) - } - - /** The result of a successful write task. 
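The refactored task loop in `FileFormatWriter` above now drives a single `FileFormatDataWriter` through a write/commit/abort lifecycle instead of handing the whole iterator to an `ExecuteWriteTask`. The sketch below captures that control flow with a stand-in trait; `DataWriterLike`, `runTask` and the `String` result type are simplifications of the real `WriteTaskResult` plumbing.

```scala
trait DataWriterLike[T] {
  def write(t: T): Unit
  def commit(): String // the real code returns a WriteTaskResult with stats and partitions
  def abort(): Unit
}

// Write every row, commit on success, abort (which also releases resources) on any failure.
def runTask[T](rows: Iterator[T], writer: DataWriterLike[T]): String =
  try {
    while (rows.hasNext) writer.write(rows.next())
    writer.commit()
  } catch {
    case t: Throwable =>
      writer.abort()
      throw t
  }
```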
*/ - private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + outputPath: String, + customPartitionLocations: Map[TablePartitionSpec, String], + outputColumns: Seq[Attribute]) /** * Basic work flow of this command is: @@ -135,9 +96,11 @@ object FileFormatWriter extends Logging { val caseInsensitiveOptions = CaseInsensitiveMap(options) + val dataSchema = dataColumns.toStructType + DataSourceUtils.verifyWriteSchema(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, @@ -262,30 +225,27 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) - val writeTask = + val dataWriter = if (sparkPartitionId != 0 && !iterator.hasNext) { // In case of empty job, leave first partition to save meta for file format like parquet. - new EmptyDirectoryWriteTask(description) + new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { - new SingleDirectoryWriteTask(description, taskAttemptContext, committer) + new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionWriteTask(description, taskAttemptContext, committer) + new DynamicPartitionDataWriter(description, taskAttemptContext, committer) } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - val summary = writeTask.execute(iterator) - writeTask.releaseResources() - WriteTaskResult(committer.commitTask(taskAttemptContext), summary) - })(catchBlock = { - // If there is an error, release resource and then abort the task - try { - writeTask.releaseResources() - } finally { - committer.abortTask(taskAttemptContext) - logError(s"Job $jobId aborted.") + while (iterator.hasNext) { + dataWriter.write(iterator.next()) } + dataWriter.commit() + })(catchBlock = { + // If there is an error, abort the task + dataWriter.abort() + logError(s"Job $jobId aborted.") }) } catch { case e: FetchFailedException => @@ -302,7 +262,7 @@ object FileFormatWriter extends Logging { private def processStats( statsTrackers: Seq[WriteJobStatsTracker], statsPerTask: Seq[Seq[WriteTaskStats]]) - : Unit = { + : Unit = { val numStatsTrackers = statsTrackers.length assert(statsPerTask.forall(_.length == numStatsTrackers), @@ -321,281 +281,4 @@ object FileFormatWriter extends Logging { case (statsTracker, stats) => statsTracker.processStats(stats) } } - - /** - * A simple trait for writing out data in a single Spark task, without any concerns about how - * to commit or abort tasks. Exceptions thrown by the implementation of this trait will - * automatically trigger task aborts. - */ - private trait ExecuteWriteTask { - - /** - * Writes data out to files, and then returns the summary of relative information which - * includes the list of partition strings written out. The list of partitions is sent back - * to the driver and used to update the catalog. Other information will be sent back to the - * driver too and used to e.g. update the metrics in UI. 
- */ - def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary - def releaseResources(): Unit - } - - /** ExecuteWriteTask for empty partitions */ - private class EmptyDirectoryWriteTask(description: WriteJobDescription) - extends ExecuteWriteTask { - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = {} - } - - /** Writes data to a single directory (used for non-dynamic-partition writes). */ - private class SingleDirectoryWriteTask( - description: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - private[this] var currentWriter: OutputWriter = _ - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - private def newOutputWriter(fileCounter: Int): Unit = { - val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - val currentPath = committer.newTaskTempFile( - taskAttemptContext, - None, - f"-c$fileCounter%03d" + ext) - - currentWriter = description.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = description.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.map(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - var fileCounter = 0 - var recordsInFile: Long = 0L - newOutputWriter(fileCounter) - - while (iter.hasNext) { - if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - recordsInFile = 0 - releaseResources() - newOutputWriter(fileCounter) - } - - val internalRow = iter.next() - currentWriter.write(internalRow) - statsTrackers.foreach(_.newRow(internalRow)) - recordsInFile += 1 - } - releaseResources() - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } - - /** - * Writes data to using dynamic partition writes, meaning this single function can write to - * multiple directories (partitions) or files (bucketing). - */ - private class DynamicPartitionWriteTask( - desc: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - /** Flag saying whether or not the data to be written out is partitioned. */ - val isPartitioned = desc.partitionColumns.nonEmpty - - /** Flag saying whether or not the data to be written out is bucketed. */ - val isBucketed = desc.bucketIdExpression.isDefined - - assert(isPartitioned || isBucketed, - s"""DynamicPartitionWriteTask should be used for writing out data that's either - |partitioned or bucketed. In this case neither is true. - |WriteJobDescription: ${desc} - """.stripMargin) - - // currentWriter is initialized whenever we see a new key (partitionValues + BucketId) - private var currentWriter: OutputWriter = _ - - /** Trackers for computing various statistics on the data as it's being written out. 
*/ - private val statsTrackers: Seq[WriteTaskStatsTracker] = - desc.statsTrackers.map(_.newTaskInstance()) - - /** Extracts the partition values out of an input row. */ - private lazy val getPartitionValues: InternalRow => UnsafeRow = { - val proj = UnsafeProjection.create(desc.partitionColumns, desc.allColumns) - row => proj(row) - } - - /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ - private lazy val partitionPathExpression: Expression = Concat( - desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val partitionName = ScalaUDF( - ExternalCatalogUtils.getPartitionPathString _, - StringType, - Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId)))) - if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) - }) - - /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns - * the partition string. */ - private lazy val getPartitionPath: InternalRow => String = { - val proj = UnsafeProjection.create(Seq(partitionPathExpression), desc.partitionColumns) - row => proj(row).getString(0) - } - - /** Given an input row, returns the corresponding `bucketId` */ - private lazy val getBucketId: InternalRow => Int = { - val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, desc.allColumns) - row => proj(row).getInt(0) - } - - /** Returns the data columns to be written given an input row */ - private val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns) - - /** - * Opens a new OutputWriter given a partition key and/or a bucket id. - * If bucket id is specified, we will append it to the end of the file name, but before the - * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet - * - * @param partitionValues the partition which all tuples being written by this `OutputWriter` - * belong to - * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to - * @param fileCounter the number of files that have been written in the past for this specific - * partition. This is used to limit the max number of records written for a - * single file. The value should start from 0. - * @param updatedPartitions the set of updated partition paths, we should add the new partition - * path of this writer to it. - */ - private def newOutputWriter( - partitionValues: Option[InternalRow], - bucketId: Option[Int], - fileCounter: Int, - updatedPartitions: mutable.Set[String]): Unit = { - - val partDir = partitionValues.map(getPartitionPath(_)) - partDir.foreach(updatedPartitions.add) - - val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - - // This must be in a form that matches our bucketing format. See BucketingUtils. 
- val ext = f"$bucketIdStr.c$fileCounter%03d" + - desc.outputWriterFactory.getFileExtension(taskAttemptContext) - - val customPath = partDir.flatMap { dir => - desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) - } - val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) - } else { - committer.newTaskTempFile(taskAttemptContext, partDir, ext) - } - - currentWriter = desc.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = desc.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.foreach(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - // If anything below fails, we should abort the task. - var recordsInFile: Long = 0L - var fileCounter = 0 - val updatedPartitions = mutable.Set[String]() - var currentPartionValues: Option[UnsafeRow] = None - var currentBucketId: Option[Int] = None - - for (row <- iter) { - val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None - val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None - - if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { - // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartionValues != nextPartitionValues) { - currentPartionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartionValues.get)) - } - if (isBucketed) { - currentBucketId = nextBucketId - statsTrackers.foreach(_.newBucket(currentBucketId.get)) - } - - recordsInFile = 0 - fileCounter = 0 - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } else if (desc.maxRecordsPerFile > 0 && - recordsInFile >= desc.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. - recordsInFile = 0 - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } - val outputRow = getOutputRow(row) - currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) - recordsInFile += 1 - } - releaseResources() - - ExecutedWriteSummary( - updatedPartitions = updatedPartitions.toSet, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } } - -/** - * Wrapper class for the metrics of writing data out. - * - * @param updatedPartitions the partitions updated during writing data out. Only valid - * for dynamic partition. - * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. 
- */ -case class ExecutedWriteSummary( - updatedPartitions: Set[String], - stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 28c36b6020d33..99fc78ff3e49b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -214,7 +214,7 @@ class FileScanRDD( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(_ => iterator.close()) + context.addTaskCompletionListener[Unit](_ => iterator.close()) iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 16b22717b8d92..fe27b78bf3360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.FileSourceScanExec -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} +import org.apache.spark.util.collection.BitSet /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -50,6 +51,91 @@ import org.apache.spark.sql.execution.SparkPlan * and add it. Proceed to the next file. 
*/ object FileSourceStrategy extends Strategy with Logging { + + // should prune buckets iff num buckets is greater than 1 and there is only one bucket column + private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = { + bucketSpec match { + case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1 + case None => false + } + } + + private def getExpressionBuckets( + expr: Expression, + bucketColumnName: String, + numBuckets: Int): BitSet = { + + def getBucketNumber(attr: Attribute, v: Any): Int = { + BucketingUtils.getBucketIdFromValue(attr, numBuckets, v) + } + + def getBucketSetFromIterable(attr: Attribute, iter: Iterable[Any]): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + iter + .map(v => getBucketNumber(attr, v)) + .foreach(bucketNum => matchedBuckets.set(bucketNum)) + matchedBuckets + } + + def getBucketSetFromValue(attr: Attribute, v: Any): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.set(getBucketNumber(attr, v)) + matchedBuckets + } + + expr match { + case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getBucketSetFromValue(a, v) + case expressions.In(a: Attribute, list) + if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) + case expressions.InSet(a: Attribute, hset) + if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow))) + case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => + getBucketSetFromValue(a, null) + case expressions.And(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) & + getExpressionBuckets(right, bucketColumnName, numBuckets) + case expressions.Or(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) | + getExpressionBuckets(right, bucketColumnName, numBuckets) + case _ => + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.setUntil(numBuckets) + matchedBuckets + } + } + + private def genBucketSet( + normalizedFilters: Seq[Expression], + bucketSpec: BucketSpec): Option[BitSet] = { + if (normalizedFilters.isEmpty) { + return None + } + + val bucketColumnName = bucketSpec.bucketColumnNames.head + val numBuckets = bucketSpec.numBuckets + + val normalizedFiltersAndExpr = normalizedFilters + .reduce(expressions.And) + val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName, + numBuckets) + + val numBucketsSelected = matchedBuckets.cardinality() + + logInfo { + s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets." 
+ } + + // None means all the buckets need to be scanned + if (numBucketsSelected == numBuckets) { + None + } else { + Some(matchedBuckets) + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => @@ -76,9 +162,19 @@ object FileSourceStrategy extends Strategy with Logging { fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec + val bucketSet = if (shouldPruneBuckets(bucketSpec)) { + genBucketSet(normalizedFilters, bucketSpec.get) + } else { + None + } + val dataColumns = l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) @@ -108,6 +204,7 @@ object FileSourceStrategy extends Strategy with Logging { outputAttributes, outputSchema, partitionKeyFilters.toSeq, + bucketSet, dataFilters, table.map(_.identifier)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 739d1f456e3ec..dc5c2ff927e4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -162,7 +162,7 @@ object InMemoryFileIndex extends Logging { * * @return for each input path, the set of discovered files for the path */ - private def bulkListLeafFiles( + private[sql] def bulkListLeafFiles( paths: Seq[Path], hadoopConf: Configuration, filter: PathFilter, @@ -294,9 +294,12 @@ object InMemoryFileIndex extends Logging { if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles } - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + val missingFiles = mutable.ArrayBuffer.empty[String] + val filteredLeafStatuses = allLeafStatuses.filterNot( + status => shouldFilterOut(status.getPath.getName)) + val resolvedLeafStatuses = filteredLeafStatuses.flatMap { case f: LocatedFileStatus => - f + Some(f) // NOTE: // @@ -311,14 +314,27 @@ object InMemoryFileIndex extends Logging { // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), // which is very slow on some file system (RawLocalFileSystem, which is launch a // subprocess and parse the stdout). 
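`getExpressionBuckets` and `genBucketSet` above turn a pushed-down filter into the set of buckets that can possibly contain matching rows: equality and IN contribute individual buckets, AND intersects, OR unions, and anything unrecognized keeps every bucket. The sketch below restates that recursion over a toy predicate ADT and plain `Set[Int]` instead of Catalyst expressions and `BitSet`; all the names are invented.

```scala
sealed trait Pred
case class Eq(value: Any) extends Pred
case class InValues(values: Seq[Any]) extends Pred
case class AndP(left: Pred, right: Pred) extends Pred
case class OrP(left: Pred, right: Pred) extends Pred
case object Unknown extends Pred // anything bucket pruning cannot reason about

def candidateBuckets(p: Pred, numBuckets: Int, bucketOf: Any => Int): Set[Int] = p match {
  case Eq(v)        => Set(bucketOf(v))
  case InValues(vs) => vs.map(bucketOf).toSet
  case AndP(l, r)   => candidateBuckets(l, numBuckets, bucketOf) intersect
                       candidateBuckets(r, numBuckets, bucketOf)
  case OrP(l, r)    => candidateBuckets(l, numBuckets, bucketOf) union
                       candidateBuckets(r, numBuckets, bucketOf)
  case Unknown      => (0 until numBuckets).toSet // cannot prune, keep all buckets
}
// As in genBucketSet, a result covering every bucket means "no pruning" (None).
```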
- val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) + try { + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + Some(lfs) + } catch { + case _: FileNotFoundException => + missingFiles += f.getPath.toString + None } - lfs } + + if (missingFiles.nonEmpty) { + logWarning( + s"the following files were missing during file scan:\n ${missingFiles.mkString("\n ")}") + } + + resolvedLeafStatuses } /** Checks if we should filter out this path name. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index a813829d50cb1..80d7608a22891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -38,9 +38,8 @@ case class InsertIntoDataSourceCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = Dataset.ofRows(sparkSession, query) - // Apply the schema of the existing table to the new data. - val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + // Data has been casted to the target relation's schema by the PreprocessTableInsertion rule. + relation.insert(data, overwrite) // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this // data source relation. 
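The `InMemoryFileIndex` change above makes listing resilient to files that disappear between the directory listing and the per-file `getFileBlockLocations` call: such files are collected and logged instead of failing the scan. A generic sketch of that pattern, with an invented `resolveAll` helper:

```scala
import java.io.FileNotFoundException
import scala.collection.mutable.ArrayBuffer

// Resolve what we can; remember what vanished so the caller can log a single warning.
def resolveAll[A, B](items: Seq[A])(resolve: A => B): (Seq[B], Seq[A]) = {
  val missing = ArrayBuffer.empty[A]
  val resolved = items.flatMap { item =>
    try Some(resolve(item))
    catch { case _: FileNotFoundException => missing += item; None }
  }
  (resolved, missing.toSeq)
}
```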
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index dd7ef0d15c140..2ae21b7df9823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -91,8 +92,12 @@ case class InsertIntoHadoopFsRelationCommand( val pathExists = fs.exists(qualifiedOutputPath) - val enableDynamicOverwrite = - sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + val parameters = CaseInsensitiveMap(options) + + val partitionOverwriteMode = parameters.get("partitionOverwriteMode") + .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase)) + .getOrElse(sparkSession.sessionState.conf.partitionOverwriteMode) + val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC // This config only makes sense when we are overwriting a partitioned dataset with dynamic // partition columns. val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && @@ -166,7 +171,15 @@ case class InsertIntoHadoopFsRelationCommand( // update metastore partition metadata - refreshUpdatedPartitions(updatedPartitionPaths) + if (updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty + && partitionColumns.length == staticPartitions.size) { + // Avoid empty static partition can't loaded to datasource table. 
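The `InsertIntoHadoopFsRelationCommand` change above reads `partitionOverwriteMode` from the write options first and only falls back to the session config, so dynamic partition overwrite can be chosen per write. A usage sketch (the local SparkSession, column names and `/tmp/events` path are made up):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("overwrite-sketch").getOrCreate()
import spark.implicits._

val df = Seq((1, "2018-07-01"), (2, "2018-07-02")).toDF("id", "date")

df.write
  .mode("overwrite")
  .option("partitionOverwriteMode", "dynamic") // overrides spark.sql.sources.partitionOverwriteMode
  .partitionBy("date")
  .parquet("/tmp/events") // only the partitions present in df are replaced
```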
+ val staticPathFragment = + PartitioningUtils.getPathFragment(staticPartitions, partitionColumns) + refreshUpdatedPartitions(Set(staticPathFragment)) + } else { + refreshUpdatedPartitions(updatedPartitionPaths) + } // refresh cached files in FileIndex fileIndex.foreach(_.refresh()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f9a24806953e6..3183fd30e5e0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -284,6 +284,10 @@ object PartitioningUtils { }.mkString("/") } + def getPathFragment(spec: TablePartitionSpec, partitionColumns: Seq[Attribute]): String = { + getPathFragment(spec, StructType.fromAttributes(partitionColumns)) + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a @@ -410,7 +414,7 @@ object PartitioningUtils { val dateTry = Try { // try and parse the date, if no exception occurs this is a candidate to be resolved as // DateType - DateTimeUtils.getThreadLocalDateFormat.parse(raw) + DateTimeUtils.getThreadLocalDateFormat(DateTimeUtils.defaultTimeZone()).parse(raw) // SPARK-23436: Casting the string to date may still return null if a bad Date is provided. // This can happen since DateFormat.parse may not use the entire text of the given string: // so if there are extra-characters after the date, it returns correctly. 
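The block above also registers a fully static partition spec in the metastore even when the insert produced no rows, by synthesizing the path fragment with `PartitioningUtils.getPathFragment`. A hedged end-to-end illustration of the case being fixed, assuming the `spark` session from the previous sketch (table name `t` and partition value are invented):

```scala
spark.sql("CREATE TABLE t (id INT, p INT) USING parquet PARTITIONED BY (p)")
// Overwrite a static partition with an empty result set.
spark.sql("INSERT OVERWRITE TABLE t PARTITION (p = 1) SELECT CAST(id AS INT) FROM range(0)")
// The empty static partition p=1 should still show up in the catalog.
spark.sql("SHOW PARTITIONS t").show()
```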
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 3b830accb83f0..16b2367bfdd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -55,7 +55,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 568e953a5db66..00b1b5dedb593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.CreatableRelationProvider -import org.apache.spark.util.Utils /** * Saves the results of `query` in to a data source. @@ -50,7 +49,7 @@ case class SaveIntoDataSourceCommand( } override def simpleString: String = { - val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap + val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4870d75fc5f08..2b86054c0ffcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -50,7 +51,11 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] + requiredSchema: StructType, + // Actual schema of data in the csv file + dataSchema: StructType, + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. 
@@ -110,7 +115,7 @@ abstract class CSVDataSource extends Serializable { } } -object CSVDataSource { +object CSVDataSource extends Logging { def apply(options: CSVOptions): CSVDataSource = { if (options.multiLine) { MultiLineCSVDataSource @@ -118,6 +123,65 @@ object CSVDataSource { TextInputCSVDataSource } } + + /** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param schema - provided (or inferred) schema to which CSV must conform. + * @param columnNames - names of CSV columns that must be checked against to the schema. + * @param fileName - name of CSV file that are currently checked. It is used in error messages. + * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column + * names are checked for conformance to the schema. In the case if + * the column name don't conform to the schema, an exception is thrown. + * @param caseSensitive - if it is set to `false`, comparison of column names and schema field + * names is not case sensitive. + */ + def checkHeaderColumnNames( + schema: StructType, + columnNames: Array[String], + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + if (columnNames != null) { + val fieldNames = schema.map(_.name).toIndexedSeq + val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) + var errorMessage: Option[String] = None + + if (headerLen == schemaSize) { + var i = 0 + while (errorMessage.isEmpty && i < headerLen) { + var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) + if (!caseSensitive) { + nameInSchema = nameInSchema.toLowerCase + nameInHeader = nameInHeader.toLowerCase + } + if (nameInHeader != nameInSchema) { + errorMessage = Some( + s"""|CSV header does not conform to the schema. + | Header: ${columnNames.mkString(", ")} + | Schema: ${fieldNames.mkString(", ")} + |Expected: ${fieldNames(i)} but found: ${columnNames(i)} + |CSV file: $fileName""".stripMargin) + } + i += 1 + } + } else { + errorMessage = Some( + s"""|Number of column in CSV header is not equal to number of fields in the schema: + | Header length: $headerLen, schema size: $schemaSize + |CSV file: $fileName""".stripMargin) + } + + errorMessage.foreach { msg => + if (enforceSchema) { + logWarning(msg) + } else { + throw new IllegalArgumentException(msg) + } + } + } + } } object TextInputCSVDataSource extends CSVDataSource { @@ -127,17 +191,37 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) linesReader.map { line => new String(line.getBytes, 0, line.getLength, parser.options.charset) } } - val shouldDropHeader = parser.options.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) + val hasHeader = parser.options.headerFlag && file.start == 0 + if (hasHeader) { + // Checking that column names in the header are matched to field names of the schema. + // The header will be removed from lines. 
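`checkHeaderColumnNames` above compares the CSV header with the schema field by field, optionally case-insensitively, and reports either the first name mismatch or a length mismatch. A compact restatement of that comparison (the `headerMismatch` helper is invented; the real method also honors `enforceSchema` by downgrading the error to a warning):

```scala
def headerMismatch(
    fieldNames: Seq[String],
    header: Seq[String],
    caseSensitive: Boolean): Option[String] = {
  def norm(s: String): String = if (caseSensitive) s else s.toLowerCase
  if (fieldNames.length != header.length) {
    Some(s"Header length: ${header.length}, schema size: ${fieldNames.length}")
  } else {
    fieldNames.zip(header).collectFirst {
      case (inSchema, inHeader) if norm(inSchema) != norm(inHeader) =>
        s"Expected: $inSchema but found: $inHeader"
    }
  }
}
```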
+ // Note: if there are only comments in the first block, the header would probably + // be not extracted. + CSVUtils.extractHeader(lines, parser.options).foreach { header => + val schema = if (columnPruning) requiredSchema else dataSchema + val columnNames = parser.tokenizer.parseLine(header) + CSVDataSource.checkHeaderColumnNames( + schema, + columnNames, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + } + + UnivocityParser.parseIterator(lines, parser, requiredSchema) } override def infer( @@ -161,7 +245,8 @@ object TextInputCSVDataSource extends CSVDataSource { val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) @@ -184,7 +269,8 @@ object TextInputCSVDataSource extends CSVDataSource { DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = options.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as[String](Encoders.STRING) } else { @@ -204,12 +290,26 @@ object MultiLineCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] = { + def checkHeader(header: Array[String]): Unit = { + val schema = if (columnPruning) requiredSchema else dataSchema + CSVDataSource.checkHeaderColumnNames( + schema, + header, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, - schema) + requiredSchema, + checkHeader) } override def infer( @@ -235,7 +335,8 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + val sampled = CSVUtils.sample(tokenRDD, parsedOptions) + CSVInferSchema.infer(sampled, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. 
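From the user's side, the header check above is controlled by the new `enforceSchema` CSV option: with the default `true` the header names are ignored (only a warning is logged on mismatch), while `false` makes a mismatching header fail the read. A usage sketch, assuming an existing `spark` session and a hypothetical `/tmp/people.csv`:

```scala
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))

val people = spark.read
  .schema(schema)
  .option("header", "true")
  .option("enforceSchema", "false") // validate the file header against the schema above
  .csv("/tmp/people.csv")
```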
StructType(Nil) @@ -248,7 +349,8 @@ object MultiLineCSVDataSource extends CSVDataSource { options: CSVOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) val name = paths.mkString(",") - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + options.parameters)) FileInputFormat.setInputPaths(job, paths: _*) val conf = job.getConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index e20977a4ec79f..9aad0bd55e736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.csv +import java.nio.charset.Charset + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ @@ -41,8 +43,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) val csvDataSource = CSVDataSource(parsedOptions) csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) } @@ -51,8 +55,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) } @@ -62,9 +68,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - CSVUtils.verifySchema(dataSchema) val conf = job.getConfiguration - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) csvOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -91,12 +99,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - CSVUtils.verifySchema(dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val parsedOptions = new CSVOptions( options, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -122,6 +130,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { 
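Both CSV schema-inference paths above now go through `CSVUtils.sample`, driven by the new `samplingRatio` option, so the schema can be inferred from a fraction of the input. A usage sketch, again assuming an existing `spark` session and an invented input path:

```scala
val logs = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .option("samplingRatio", "0.1") // infer types from roughly 10% of the lines
  .csv("/tmp/logs/*.csv")

logs.printSchema()
```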
"df.filter($\"_corrupt_record\".isNotNull).count()." ) } + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val columnPruning = sparkSession.sessionState.conf.csvColumnPruning (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -129,7 +139,14 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) + CSVDataSource(parsedOptions).readFile( + conf, + file, + parser, + requiredSchema, + dataSchema, + caseSensitive, + columnPruning) } } @@ -138,6 +155,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } + } private[csv] class CsvOutputWriter( @@ -146,7 +172,9 @@ private[csv] class CsvOutputWriter( context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val charset = Charset.forName(params.charset) + + private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) private val gen = new UnivocityGenerator(dataSchema, writer, params) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index c16790630ce17..fab8d62da0c1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -27,17 +27,20 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ class CSVOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], + val columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { def this( parameters: Map[String, String], + columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String = "") = { this( CaseInsensitiveMap(parameters), + columnPruning, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) } @@ -150,6 +153,15 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + /** + * Forcibly apply the specified or inferred schema to datasource files. + * If the option is enabled, headers of CSV files will be ignored. 
+ */ + val enforceSchema = getBool("enforceSchema", default = true) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -161,7 +173,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue(nullValue) + writerSettings.setEmptyValue("\"\"") writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -182,6 +194,7 @@ class CSVOptions( settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) settings.setNullValue(nullValue) + settings.setEmptyValue("") settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 72b053d2092ca..7ce65fa89b02d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.csv +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -67,12 +68,8 @@ object CSVUtils { } } - /** - * Drop header line so that only data can remain. - * This is similar with `filterHeaderLine` above and currently being used in CSV reading path. - */ - def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - val nonEmptyLines = if (options.isCommentSet) { + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { val commentPrefix = options.comment.toString iter.dropWhile { line => line.trim.isEmpty || line.trim.startsWith(commentPrefix) @@ -80,11 +77,19 @@ object CSVUtils { } else { iter.dropWhile(_.trim.isEmpty) } - - if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) - iter } + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } /** * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one @@ -113,22 +118,28 @@ object CSVUtils { } /** - * Verify if the schema is supported in CSV datasource. + * Sample CSV dataset as configured by `samplingRatio`. 
*/ - def verifySchema(schema: StructType): Unit = { - def verifyType(dataType: DataType): Unit = dataType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | - DoubleType | BooleanType | _: DecimalType | TimestampType | - DateType | StringType => - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - case _ => - throw new UnsupportedOperationException( - s"CSV data source does not support ${dataType.simpleString} data type.") + def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) } - - schema.foreach(field => verifyType(field.dataType)) } + /** + * Sample CSV RDD as configured by `samplingRatio`. + */ + def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 3d6cc30f2ba83..e15af425b2649 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.InputStream import java.math.BigDecimal -import java.text.NumberFormat -import java.util.Locale import scala.util.Try import scala.util.control.NonFatal @@ -35,19 +33,47 @@ import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String + +/** + * Constructs a parser for a given schema that translates CSV data to an [[InternalRow]]. + * + * @param dataSchema The CSV data schema that is specified by the user, or inferred from underlying + * data files. + * @param requiredSchema The schema of the data that should be output for each row. This should be a + * subset of the columns in dataSchema. + * @param options Configuration options for a CSV parser. + */ class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), - "requiredSchema should be the subset of schema.") + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), + s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " + + s"dataSchema (${dataSchema.catalogString}).") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + // This index is used to reorder parsed tokens + private val tokenIndexArr = + requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))).toArray + + // When column pruning is enabled, the parser only parses the required columns based on + // their positions in the data schema. 
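Since the pruning relies on univocity's field selection, a minimal standalone sketch of that behaviour may help; the indexes and input line are made up, and the expected output reflects my reading of the univocity-parsers API rather than something verified here.

import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}

val settings = new CsvParserSettings()
// Materialize only columns 0 and 2, mirroring what the pruned tokenizer defined
// just below does via parserSetting.selectIndexes(tokenIndexArr: _*).
settings.selectIndexes(Integer.valueOf(0), Integer.valueOf(2))
val parser = new CsvParser(settings)
val tokens = parser.parseLine("a,b,c")  // expected: Array("a", "c")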
+ private val parsedSchema = if (options.columnPruning) requiredSchema else dataSchema + + val tokenizer = { + val parserSetting = options.asParserSettings + // When to-be-parsed schema is shorter than the to-be-read data schema, we let Univocity CSV + // parser select a sequence of fields for reading by their positions. + // if (options.columnPruning && requiredSchema.length < dataSchema.length) { + if (parsedSchema.length < dataSchema.length) { + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } private val row = new GenericInternalRow(requiredSchema.length) @@ -75,11 +101,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray + private val valueConverters: Array[ValueConverter] = { + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } /** @@ -186,15 +209,21 @@ class UnivocityParser( */ def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + private val getToken = if (options.columnPruning) { + (tokens: Array[String], index: Int) => tokens(index) + } else { + (tokens: Array[String], index: Int) => tokens(tokenIndexArr(index)) + } + private def convert(tokens: Array[String]): InternalRow = { - if (tokens.length != schema.length) { + if (tokens.length != parsedSchema.length) { // If the number of tokens doesn't match the schema, we should treat it as a malformed record. // However, we still have chance to parse some of the tokens, by adding extra null tokens in // the tail if the number is smaller, or by dropping extra tokens if the number is larger. - val checkedTokens = if (schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) + val checkedTokens = if (parsedSchema.length > tokens.length) { + tokens ++ new Array[String](parsedSchema.length - tokens.length) } else { - tokens.take(schema.length) + tokens.take(parsedSchema.length) } def getPartialResult(): Option[InternalRow] = { try { @@ -211,10 +240,11 @@ class UnivocityParser( new RuntimeException("Malformed CSV record")) } else { try { + // When the length of the returned tokens is identical to the length of the parsed schema, + // we just need to convert the tokens that correspond to the required columns. 
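At the user level, this split between the pruned and non-pruned paths is controlled by the CSV column-pruning SQL conf; a hedged sketch follows, where the conf key is assumed to be spark.sql.csv.parser.columnPruning.enabled and the path is illustrative.

// With the conf enabled, univocity hands back only the required columns and the
// getToken helper above reads them positionally; with it disabled, every column
// is parsed and tokenIndexArr remaps tokens into the required order.
spark.conf.set("spark.sql.csv.parser.columnPruning.enabled", "false")
spark.read.option("header", "true").csv("/tmp/people.csv").select("name").show()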
var i = 0 while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + row(i) = valueConverters(i).apply(getToken(tokens, i)) i += 1 } row @@ -248,14 +278,16 @@ private[csv] object UnivocityParser { inputStream: InputStream, shouldDropHeader: Boolean, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + schema: StructType, + checkHeader: Array[String] => Unit): Iterator[InternalRow] = { val tokenizer = parser.tokenizer val safeParser = new FailureSafeParser[Array[String]]( input => Seq(parser.convert(input)), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) - convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) + convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => safeParser.parse(tokens) }.flatten } @@ -263,11 +295,14 @@ private[csv] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer: CsvParser, + checkHeader: Array[String] => Unit = _ => ())( + convert: Array[String] => T) = new Iterator[T] { tokenizer.beginParsing(inputStream) private var nextRecord = { if (shouldDropHeader) { - tokenizer.parseNext() + val firstRecord = tokenizer.parseNext() + checkHeader(firstRecord) } tokenizer.parseNext() } @@ -289,27 +324,18 @@ private[csv] object UnivocityParser { */ def parseIterator( lines: Iterator[String], - shouldDropHeader: Boolean, parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { val options = parser.options - val linesWithoutHeader = if (shouldDropHeader) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, options) - } else { - lines - } - - val filteredLines: Iterator[String] = - CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options) val safeParser = new FailureSafeParser[String]( input => Seq(parser.parse(input)), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) filteredLines.flatMap(safeParser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index b4e5d169066d9..7dfbb9d8b5c05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType * Options for the JDBC data source. 
*/ class JDBCOptions( - @transient private val parameters: CaseInsensitiveMap[String]) + @transient val parameters: CaseInsensitiveMap[String]) extends Serializable { import JDBCOptions._ @@ -65,11 +65,31 @@ class JDBCOptions( // Required parameters // ------------------------------------------------------------ require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") - require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL val url = parameters(JDBC_URL) - // name of table - val table = parameters(JDBC_TABLE_NAME) + // table name or a table subquery. + val tableOrQuery = (parameters.get(JDBC_TABLE_NAME), parameters.get(JDBC_QUERY_STRING)) match { + case (Some(name), Some(subquery)) => + throw new IllegalArgumentException( + s"Both '$JDBC_TABLE_NAME' and '$JDBC_QUERY_STRING' can not be specified at the same time." + ) + case (None, None) => + throw new IllegalArgumentException( + s"Option '$JDBC_TABLE_NAME' or '$JDBC_QUERY_STRING' is required." + ) + case (Some(name), None) => + if (name.isEmpty) { + throw new IllegalArgumentException(s"Option '$JDBC_TABLE_NAME' can not be empty.") + } else { + name.trim + } + case (None, Some(subquery)) => + if (subquery.isEmpty) { + throw new IllegalArgumentException(s"Option `$JDBC_QUERY_STRING` can not be empty.") + } else { + s"(${subquery}) __SPARK_GEN_JDBC_SUBQUERY_NAME_${curId.getAndIncrement()}" + } + } // ------------------------------------------------------------ // Optional parameters @@ -89,15 +109,19 @@ class JDBCOptions( // the number of partitions val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt) + // the maximum number of seconds the driver will wait for a Statement object to execute. + // Zero means there is no limit. + val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt + // ------------------------------------------------------------ // Optional parameters only for reading // ------------------------------------------------------------ // the column used to partition val partitionColumn = parameters.get(JDBC_PARTITION_COLUMN) // the lower bound of partition column - val lowerBound = parameters.get(JDBC_LOWER_BOUND).map(_.toLong) + val lowerBound = parameters.get(JDBC_LOWER_BOUND) // the upper bound of the partition column - val upperBound = parameters.get(JDBC_UPPER_BOUND).map(_.toLong) + val upperBound = parameters.get(JDBC_UPPER_BOUND) // numPartitions is also used for data source writing require((partitionColumn.isEmpty && lowerBound.isEmpty && upperBound.isEmpty) || (partitionColumn.isDefined && lowerBound.isDefined && upperBound.isDefined && @@ -105,6 +129,20 @@ class JDBCOptions( s"When reading JDBC data sources, users need to specify all or none for the following " + s"options: '$JDBC_PARTITION_COLUMN', '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', " + s"and '$JDBC_NUM_PARTITIONS'") + + require(!(parameters.get(JDBC_QUERY_STRING).isDefined && partitionColumn.isDefined), + s""" + |Options '$JDBC_QUERY_STRING' and '$JDBC_PARTITION_COLUMN' can not be specified together. + |Please define the query using `$JDBC_TABLE_NAME` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity.
+ |Example: + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1") + | .load() + """.stripMargin + ) + val fetchSize = { val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt require(size >= 0, @@ -119,6 +157,8 @@ class JDBCOptions( // ------------------------------------------------------------ // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean + + val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) // the create table option, which can be table_options or partition_options. // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options @@ -143,9 +183,35 @@ class JDBCOptions( } // An option to execute custom SQL before fetching data from the remote DB val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) + + // An option to allow/disallow pushing down predicate into JDBC data source + val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean +} + +class JdbcOptionsInWrite( + @transient override val parameters: CaseInsensitiveMap[String]) + extends JDBCOptions(parameters) { + + import JDBCOptions._ + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + def this(url: String, table: String, parameters: Map[String, String]) = { + this(CaseInsensitiveMap(parameters ++ Map( + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> table))) + } + + require( + parameters.get(JDBC_TABLE_NAME).isDefined, + s"Option '$JDBC_TABLE_NAME' is required. " + + s"Option '$JDBC_QUERY_STRING' is not applicable while writing.") + + val table = parameters(JDBC_TABLE_NAME) } object JDBCOptions { + private val curId = new java.util.concurrent.atomic.AtomicLong(0L) private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { @@ -155,17 +221,21 @@ object JDBCOptions { val JDBC_URL = newOption("url") val JDBC_TABLE_NAME = newOption("dbtable") + val JDBC_QUERY_STRING = newOption("query") val JDBC_DRIVER_CLASS = newOption("driver") val JDBC_PARTITION_COLUMN = newOption("partitionColumn") val JDBC_LOWER_BOUND = newOption("lowerBound") val JDBC_UPPER_BOUND = newOption("upperBound") val JDBC_NUM_PARTITIONS = newOption("numPartitions") + val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") + val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") + val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 05326210f3242..16b493892e3be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@
-51,12 +51,13 @@ object JDBCRDD extends Logging { */ def resolveTable(options: JDBCOptions): StructType = { val url = options.url - val table = options.table + val table = options.tableOrQuery val dialect = JdbcDialects.get(url) val conn: Connection = JdbcUtils.createConnectionFactory(options)() try { val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { + statement.setQueryTimeout(options.queryTimeout) val rs = statement.executeQuery() try { JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) @@ -264,7 +265,7 @@ private[jdbc] class JDBCRDD( closed = true } - context.addTaskCompletionListener{ context => close() } + context.addTaskCompletionListener[Unit]{ context => close() } val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] @@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD( val statement = conn.prepareStatement(sql) logInfo(s"Executing sessionInitStatement: $sql") try { + statement.setQueryTimeout(options.queryTimeout) statement.execute() } finally { statement.close() @@ -294,10 +296,11 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause" + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) + stmt.setQueryTimeout(options.queryTimeout) rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b23e5a7722004..f15014442e3fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -17,21 +17,27 @@ package org.apache.spark.sql.execution.datasources.jdbc +import java.sql.{Date, Timestamp} + import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType} +import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. */ private[sql] case class JDBCPartitioningInfo( column: String, + columnType: DataType, lowerBound: Long, upperBound: Long, numPartitions: Int) @@ -48,10 +54,44 @@ private[sql] object JDBCRelation extends Logging { * Null value predicate is added to the first partition where clause to include * the rows with null value for the partitions column. 
* - * @param partitioning partition information to generate the where clause for each partition + * @param schema resolved schema of a JDBC table + * @param resolver function used to determine if two identifiers are equal + * @param timeZoneId timezone ID to be used if a partition column type is date or timestamp + * @param jdbcOptions JDBC options that contains url * @return an array of partitions with where clause for each partition */ - def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { + def columnPartition( + schema: StructType, + resolver: Resolver, + timeZoneId: String, + jdbcOptions: JDBCOptions): Array[Partition] = { + val partitioning = { + import JDBCOptions._ + + val partitionColumn = jdbcOptions.partitionColumn + val lowerBound = jdbcOptions.lowerBound + val upperBound = jdbcOptions.upperBound + val numPartitions = jdbcOptions.numPartitions + + if (partitionColumn.isEmpty) { + assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not " + + s"specified, '$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty") + null + } else { + assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty, + s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " + + s"'$JDBC_NUM_PARTITIONS' are also required") + + val (column, columnType) = verifyAndGetNormalizedPartitionColumn( + schema, partitionColumn.get, resolver, jdbcOptions) + + val lowerBoundValue = toInternalBoundValue(lowerBound.get, columnType) + val upperBoundValue = toInternalBoundValue(upperBound.get, columnType) + JDBCPartitioningInfo( + column, columnType, lowerBoundValue, upperBoundValue, numPartitions.get) + } + } + if (partitioning == null || partitioning.numPartitions <= 1 || partitioning.lowerBound == partitioning.upperBound) { return Array[Partition](JDBCPartition(null, 0)) @@ -63,6 +103,8 @@ private[sql] object JDBCRelation extends Logging { "Operation not allowed: the lower bound of partitioning column is larger than the upper " + s"bound. Lower bound: $lowerBound; Upper bound: $upperBound") + val boundValueToString: Long => String = + toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId) val numPartitions = if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */ (upperBound - lowerBound) < 0) { @@ -71,21 +113,25 @@ private[sql] object JDBCRelation extends Logging { logWarning("The number of partitions is reduced because the specified number of " + "partitions is less than the difference between upper bound and lower bound. " + s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " + - s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " + - s"Upper bound: $upperBound.") + s"partitions: ${partitioning.numPartitions}; " + + s"Lower bound: ${boundValueToString(lowerBound)}; " + + s"Upper bound: ${boundValueToString(upperBound)}.") upperBound - lowerBound } // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. 
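As a user-level illustration of the date-typed bounds that columnPartition now understands, here is a hedged sketch; the JDBC URL, table, column, and bound values are assumptions for the example only.

spark.read.format("jdbc")
  .option("url", "jdbc:postgresql://db-host:5432/shop")  // illustrative URL
  .option("dbtable", "orders")
  .option("partitionColumn", "order_date")  // DATE column; previously only numeric columns worked
  .option("lowerBound", "2018-01-01")
  .option("upperBound", "2019-01-01")
  .option("numPartitions", "4")
  .option("queryTimeout", "30")  // new option, forwarded to Statement.setQueryTimeout in seconds
  .load()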
val stride: Long = upperBound / numPartitions - lowerBound / numPartitions - val column = partitioning.column + var i: Int = 0 - var currentValue: Long = lowerBound + val column = partitioning.column + var currentValue = lowerBound val ans = new ArrayBuffer[Partition]() while (i < numPartitions) { - val lBound = if (i != 0) s"$column >= $currentValue" else null + val lBoundValue = boundValueToString(currentValue) + val lBound = if (i != 0) s"$column >= $lBoundValue" else null currentValue += stride - val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null + val uBoundValue = boundValueToString(currentValue) + val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null val whereClause = if (uBound == null) { lBound @@ -97,32 +143,109 @@ private[sql] object JDBCRelation extends Logging { ans += JDBCPartition(whereClause, i) i = i + 1 } - ans.toArray + val partitions = ans.toArray + logInfo(s"Number of partitions: $numPartitions, WHERE clauses of these partitions: " + + partitions.map(_.asInstanceOf[JDBCPartition].whereClause).mkString(", ")) + partitions } -} -private[sql] case class JDBCRelation( - parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) - extends BaseRelation - with PrunedFilteredScan - with InsertableRelation { + // Verify column name and type based on the JDBC resolved schema + private def verifyAndGetNormalizedPartitionColumn( + schema: StructType, + columnName: String, + resolver: Resolver, + jdbcOptions: JDBCOptions): (String, DataType) = { + val dialect = JdbcDialects.get(jdbcOptions.url) + val column = schema.find { f => + resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName) + }.getOrElse { + throw new AnalysisException(s"User-defined partition column $columnName not " + + s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + } + column.dataType match { + case _: NumericType | DateType | TimestampType => + case _ => + throw new AnalysisException( + s"Partition column type should be ${NumericType.simpleString}, " + + s"${DateType.catalogString}, or ${TimestampType.catalogString}, but " + + s"${column.dataType.catalogString} found.") + } + (dialect.quoteIdentifier(column.name), column.dataType) + } - override def sqlContext: SQLContext = sparkSession.sqlContext + private def toInternalBoundValue(value: String, columnType: DataType): Long = columnType match { + case _: NumericType => value.toLong + case DateType => DateTimeUtils.fromJavaDate(Date.valueOf(value)).toLong + case TimestampType => DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(value)) + } - override val needConversion: Boolean = false + private def toBoundValueInWhereClause( + value: Long, + columnType: DataType, + timeZoneId: String): String = { + def dateTimeToString(): String = { + val timeZone = DateTimeUtils.getTimeZone(timeZoneId) + val dateTimeStr = columnType match { + case DateType => DateTimeUtils.dateToString(value.toInt, timeZone) + case TimestampType => DateTimeUtils.timestampToString(value, timeZone) + } + s"'$dateTimeStr'" + } + columnType match { + case _: NumericType => value.toString + case DateType | TimestampType => dateTimeToString() + } + } - override val schema: StructType = { + /** + * Takes a (schema, table) specification and returns the table's Catalyst schema. + * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the + * custom schema's type. 
+ * + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url, table and other information. + * @return resolved Catalyst schema of a JDBC table + */ + def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = { val tableSchema = JDBCRDD.resolveTable(jdbcOptions) jdbcOptions.customSchema match { case Some(customSchema) => JdbcUtils.getCustomSchema( - tableSchema, customSchema, sparkSession.sessionState.conf.resolver) + tableSchema, customSchema, resolver) case None => tableSchema } } + /** + * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema. + */ + def apply( + parts: Array[Partition], + jdbcOptions: JDBCOptions)( + sparkSession: SparkSession): JDBCRelation = { + val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sparkSession) + } +} + +private[sql] case class JDBCRelation( + override val schema: StructType, + parts: Array[Partition], + jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) + extends BaseRelation + with PrunedFilteredScan + with InsertableRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override val needConversion: Boolean = false + // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { - filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) + if (jdbcOptions.pushDownPredicate) { + filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) + } else { + filters + } } override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { @@ -139,12 +262,12 @@ private[sql] case class JDBCRelation( override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties) } override def toString: String = { val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else "" // credentials should not be included in the plan output, table information is sufficient. 
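For completeness, a hedged sketch of the new query and pushDownPredicate read options that JDBCRelation now honours; the connection details and SQL text are assumptions.

spark.read.format("jdbc")
  .option("url", "jdbc:postgresql://db-host:5432/shop")  // illustrative URL
  // 'query' is wrapped in a generated subquery alias instead of requiring 'dbtable';
  // it cannot be combined with 'partitionColumn'.
  .option("query", "select id, amount from orders where amount > 100")
  // Keep filters in Spark rather than compiling them into the JDBC WHERE clause.
  .option("pushDownPredicate", "false")
  .load()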
- s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo + s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index cc506e51bd0c6..e7456f9c8ed0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -29,27 +29,12 @@ class JdbcRelationProvider extends CreatableRelationProvider override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - import JDBCOptions._ - val jdbcOptions = new JDBCOptions(parameters) - val partitionColumn = jdbcOptions.partitionColumn - val lowerBound = jdbcOptions.lowerBound - val upperBound = jdbcOptions.upperBound - val numPartitions = jdbcOptions.numPartitions - - val partitionInfo = if (partitionColumn.isEmpty) { - assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " + - s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty") - null - } else { - assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty, - s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " + - s"'$JDBC_NUM_PARTITIONS' are also required") - JDBCPartitioningInfo( - partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) + val resolver = sqlContext.conf.resolver + val timeZoneId = sqlContext.conf.sessionLocalTimeZone + val schema = JDBCRelation.getSchema(resolver, jdbcOptions) + val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) } override def createRelation( @@ -57,7 +42,7 @@ class JdbcRelationProvider extends CreatableRelationProvider mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = { - val options = new JDBCOptions(parameters) + val options = new JdbcOptionsInWrite(parameters) val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis val conn = JdbcUtils.createConnectionFactory(options)() @@ -73,7 +58,7 @@ class JdbcRelationProvider extends CreatableRelationProvider saveTable(df, tableSchema, isCaseSensitive, options) } else { // Otherwise, do not truncate the table, instead drop and recreate it - dropTable(conn, options.table) + dropTable(conn, options.table, options) createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } @@ -84,7 +69,8 @@ class JdbcRelationProvider extends CreatableRelationProvider case SaveMode.ErrorIfExists => throw new AnalysisException( - s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.") + s"Table or view '${options.table}' already exists. 
" + + s"SaveMode: ErrorIfExists.") case SaveMode.Ignore => // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e6dc2fda4eb1b..edea549748b47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -67,7 +67,7 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, options: JDBCOptions): Boolean = { + def tableExists(conn: Connection, options: JdbcOptionsInWrite): Boolean = { val dialect = JdbcDialects.get(options.url) // Somewhat hacky, but there isn't a good way to identify whether a table exists for all @@ -76,6 +76,7 @@ object JdbcUtils extends Logging { Try { val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) statement.executeQuery() } finally { statement.close() @@ -86,9 +87,10 @@ object JdbcUtils extends Logging { /** * Drops a table from the JDBC database. */ - def dropTable(conn: Connection, table: String): Unit = { + def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = { val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(s"DROP TABLE $table") } finally { statement.close() @@ -98,11 +100,17 @@ object JdbcUtils extends Logging { /** * Truncates a table from the JDBC database without side effects. */ - def truncateTable(conn: Connection, options: JDBCOptions): Unit = { + def truncateTable(conn: Connection, options: JdbcOptionsInWrite): Unit = { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { - statement.executeUpdate(dialect.getTruncateQuery(options.table)) + statement.setQueryTimeout(options.queryTimeout) + val truncateQuery = if (options.isCascadeTruncate.isDefined) { + dialect.getTruncateQuery(options.table, options.isCascadeTruncate) + } else { + dialect.getTruncateQuery(options.table) + } + statement.executeUpdate(truncateQuery) } finally { statement.close() } @@ -172,7 +180,7 @@ object JdbcUtils extends Logging { private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( - throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) + throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}")) } /** @@ -252,8 +260,9 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) try { - val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) + val statement = conn.prepareStatement(dialect.getSchemaQuery(options.tableOrQuery)) try { + statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) } catch { case _: SQLException => None @@ -476,7 +485,7 @@ object JdbcUtils extends Logging { case LongType if metadata.contains("binarylong") => throw new IllegalArgumentException(s"Unsupported array element " + - s"type ${dt.simpleString} based on binary") + s"type ${dt.catalogString} based on binary") case ArrayType(_, _) => throw new IllegalArgumentException("Nested arrays unsupported") @@ -490,7 +499,7 @@ object 
JdbcUtils extends Logging { array => new GenericArrayData(elementConversion.apply(array.getArray))) row.update(pos, array) - case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.catalogString}") } private def nullSafeConvert[T](input: T, f: T => Any): Any = { @@ -596,7 +605,8 @@ object JdbcUtils extends Logging { insertStmt: String, batchSize: Int, dialect: JdbcDialect, - isolationLevel: Int): Iterator[Byte] = { + isolationLevel: Int, + options: JDBCOptions): Iterator[Byte] = { val conn = getConnection() var committed = false @@ -637,6 +647,9 @@ object JdbcUtils extends Logging { try { var rowCount = 0 + + stmt.setQueryTimeout(options.queryTimeout) + while (iterator.hasNext) { val row = iterator.next() var i = 0 @@ -801,7 +814,7 @@ object JdbcUtils extends Logging { df: DataFrame, tableSchema: Option[StructType], isCaseSensitive: Boolean, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val url = options.url val table = options.table val dialect = JdbcDialects.get(url) @@ -819,7 +832,8 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.rdd.foreachPartition(iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) + getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, + options) ) } @@ -829,7 +843,7 @@ object JdbcUtils extends Logging { def createTable( conn: Connection, df: DataFrame, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val strSchema = schemaString( df, options.url, options.createTableColumnTypes) val table = options.table @@ -841,6 +855,7 @@ object JdbcUtils extends Logging { val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(sql) } finally { statement.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 5769c09c9a1d9..76f58371ae264 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -31,11 +31,12 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,32 +93,33 @@ object TextInputJsonDataSource extends JsonDataSource { 
sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset( - sparkSession, inputPaths, parsedOptions.lineSeparator) + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) - JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd + val rowParser = parsedOptions.encoding.map { enc => + CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) + }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) + + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - lineSeparator: Option[String]): Dataset[String] = { - val textOptions = lineSeparator.map { lineSep => - Map(TextOptions.LINE_SEPARATOR -> lineSep) - }.getOrElse(Map.empty[String, String]) - - val paths = inputPaths.map(_.getPath.toString) + parsedOptions: JSONOptions): Dataset[String] = { sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, - options = textOptions + options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -128,12 +130,17 @@ object TextInputJsonDataSource extends JsonDataSource { parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) + val textParser = parser.options.encoding + .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) + .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) + val safeParser = new FailureSafeParser[Text]( - input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) linesReader.flatMap(safeParser.parse) } @@ -151,16 +158,24 @@ object MultiLineJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) + val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, createParser) + val parser = parsedOptions.encoding + .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) + .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) + + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( 
sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + parsedOptions.parameters)) val conf = job.getConfiguration val name = paths.mkString(",") FileInputFormat.setInputPaths(job, paths: _*) @@ -175,11 +190,18 @@ object MultiLineJsonDataSource extends JsonDataSource { .values } - private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { - val path = new Path(record.getPath()) - CreateJacksonParser.inputStream( - jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) + private def dataToInputStream(dataStream: PortableDataStream): InputStream = { + val path = new Path(dataStream.getPath()) + CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path) + } + + private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream)) + } + + private def createParser(enc: String, jsonFactory: JsonFactory, + stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream)) } override def readFile( @@ -194,12 +216,16 @@ object MultiLineJsonDataSource extends JsonDataSource { UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } + val streamParser = parser.options.encoding + .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream)) + .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream)) val safeParser = new FailureSafeParser[InputStream]( - input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) safeParser.parse( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0862c746fffad..a9241afba537b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -24,11 +26,11 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} import org.apache.spark.sql.catalyst.util.CompressionCodecs import 
org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { @@ -38,7 +40,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -50,7 +52,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -97,7 +99,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -142,6 +144,23 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => true + + case _ => false + } } private[json] class JsonOutputWriter( @@ -151,7 +170,18 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val encoding = options.encoding match { + case Some(charsetName) => Charset.forName(charsetName) + case None => StandardCharsets.UTF_8 + } + + if (JSONOptionsInRead.blacklist.contains(encoding)) { + logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" + + " which can be read back by Spark only if multiLine is enabled.") + } + + private val writer = CodecStreams.createOutputStreamWriter( + context, new Path(path), encoding) // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 1de2ca2914c44..4574f8247af54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -59,6 +59,19 @@ private[sql] object OrcFileFormat { def checkFieldNames(names: Seq[String]): Unit = { names.foreach(checkFieldName) } + + def getQuotedSchemaString(dataType: DataType): String = dataType match { + case _: AtomicType => dataType.catalogString + case StructType(fields) => + fields.map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}") + .mkString("struct<", ",", ">") + case ArrayType(elementType, _) => + s"array<${getQuotedSchemaString(elementType)}>" + case MapType(keyType, valueType, _) => + s"map<${getQuotedSchemaString(keyType)},${getQuotedSchemaString(valueType)}>" + case _ => // UDT and others + dataType.catalogString + } } /** @@ -93,7 +106,7 @@ class OrcFileFormat val conf = job.getConfiguration - conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema)) conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) @@ -192,7 +205,7 @@ class OrcFileFormat // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( @@ -207,7 +220,7 @@ class OrcFileFormat val orcRecordReader = new OrcInputFormat[OrcStruct] .createRecordReader(fileSplit, taskAttemptContext) val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) @@ -224,4 +237,21 @@ class OrcFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4f44ae4fa1d71..c4c3b3053a3b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -98,7 +98,7 @@ private[orc] object OrcFilters { case DateType => PredicateLeaf.Type.DATE case TimestampType => PredicateLeaf.Type.TIMESTAMP case _: DecimalType => PredicateLeaf.Type.DECIMAL - case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.catalogString}") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 899af0750cadf..90d1268028096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -223,6 +223,6 @@ class OrcSerializer(dataSchema: StructType) { * Return a Orc value object for the given Spark schema. */ private def createOrcValue(dataType: DataType) = { - OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString)) + OrcStruct.createValue(TypeDescription.fromString(OrcFileFormat.getQuotedSchemaString(dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 460194ba61c8b..ac062fdc092ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -79,9 +79,10 @@ object OrcUtils extends Logging { val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConf() // TODO: We need to support merge schema. Please see SPARK-11412. - files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema => - logDebug(s"Reading schema from file $files, got Hive schema string: $schema") - CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] + files.toIterator.map(file => readSchema(file.getPath, conf, ignoreCorruptFiles)).collectFirst { + case Some(schema) => + logDebug(s"Reading schema from file $files, got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] } } @@ -104,7 +105,7 @@ object OrcUtils extends Logging { // This is a ORC file written by Hive, no field names in the physical schema, assume the // physical schema maps to the data scheme by index. 
assert(orcFieldNames.length <= dataSchema.length, "The given data schema " + - s"${dataSchema.simpleString} has less fields than the actual ORC physical schema, " + + s"${dataSchema.catalogString} has fewer fields than the actual ORC physical schema, " + "no idea which columns were dropped, fail to read.") Some(requiredSchema.fieldNames.map { name => val index = dataSchema.fieldIndex(name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d8f47eec952de..d7eb14356b8b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -22,7 +22,6 @@ import java.net.URI import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.parallel.ForkJoinTaskSupport import scala.util.{Failure, Try} import org.apache.hadoop.conf.Configuration @@ -34,6 +33,7 @@ import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType @@ -77,7 +77,6 @@ class ParquetFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) val conf = ContextUtil.getConfiguration(job) @@ -125,16 +124,17 @@ class ParquetFileFormat conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) // SPARK-15719: Disables writing Parquet summary files by default. - if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { - conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) } - if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (ParquetOutputFormat.getJobSummaryLevel(conf) != JobSummaryLevel.NONE && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { // output summary is requested, but the class is not a Parquet Committer logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + s" create job summaries. 
" + - s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") } new OutputWriterFactory { @@ -310,6 +310,9 @@ class ParquetFileFormat hadoopConf.set( SQLConf.SESSION_LOCAL_TIMEZONE.key, sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) @@ -333,37 +336,28 @@ class ParquetFileFormat val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) - val enableRecordFilter: Boolean = - sparkSession.sessionState.conf.parquetRecordFilterEnabled - val timestampConversion: Boolean = - sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize - val enableParquetFilterPushDown: Boolean = - sparkSession.sessionState.conf.parquetFilterPushDown + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - // Try to push down filters when filter push-down is enabled. - val pushed = if (enableParquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val filePath = fileSplit.getPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( - fileSplit.getPath, + filePath, fileSplit.getStart, fileSplit.getStart + fileSplit.getLength, fileSplit.getLength, @@ -371,16 +365,34 @@ class ParquetFileFormat null) val sharedConf = broadcastedHadoopConf.value.value + + lazy val footerFileMetaData = + ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal, + pushDownStringStartWith, pushDownInFilterThreshold) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. 
+ .flatMap(parquetFilters.createFilter(parquetSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may // have different writers. - def isCreatedByParquetMr(): Boolean = { - val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS) - footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr") - } + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. + def isCreatedByParquetMr: Boolean = + footerFileMetaData.getCreatedBy().startsWith("parquet-mr") + val convertTz = - if (timestampConversion && !isCreatedByParquetMr()) { + if (timestampConversion && !isCreatedByParquetMr) { Some(DateTimeUtils.getTimeZone(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) } else { None @@ -401,7 +413,7 @@ class ParquetFileFormat convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) val iter = new RecordReaderIterator(vectorizedReader) // SPARK-23457 Register a task completion lister before `initialization`. - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) @@ -422,7 +434,7 @@ class ParquetFileFormat } val iter = new RecordReaderIterator(reader) // SPARK-23457 Register a task completion lister before `initialization`. - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) reader.initialize(split, hadoopAttemptContext) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes @@ -442,6 +454,21 @@ class ParquetFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } } object ParquetFileFormat extends Logging { @@ -507,30 +534,23 @@ object ParquetFileFormat extends Logging { conf: Configuration, partFiles: Seq[FileStatus], ignoreCorruptFiles: Boolean): Seq[Footer] = { - val parFiles = partFiles.par - val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8) - parFiles.tasksupport = new ForkJoinTaskSupport(pool) - try { - parFiles.flatMap { currentFile => - try { - // Skips row group information since we only need the schema. - // ParquetFileReader.readFooter throws RuntimeException, instead of IOException, - // when it can't read the footer. 
- Some(new Footer(currentFile.getPath(), - ParquetFileReader.readFooter( - conf, currentFile, SKIP_ROW_GROUPS))) - } catch { case e: RuntimeException => - if (ignoreCorruptFiles) { - logWarning(s"Skipped the footer in the corrupted file: $currentFile", e) - None - } else { - throw new IOException(s"Could not read footer for file: $currentFile", e) - } + ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile => + try { + // Skips row group information since we only need the schema. + // ParquetFileReader.readFooter throws RuntimeException, instead of IOException, + // when it can't read the footer. + Some(new Footer(currentFile.getPath(), + ParquetFileReader.readFooter( + conf, currentFile, SKIP_ROW_GROUPS))) + } catch { case e: RuntimeException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $currentFile", e) + None + } else { + throw new IOException(s"Could not read footer for file: $currentFile", e) } - }.seq - } finally { - pool.shutdown() - } + } + }.flatten } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index ccc8306866d68..58b4a769fcb62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,196 +17,399 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.sql.Date +import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator} +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources -import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Some utility function to convert Spark data source filters to Parquet filters. 
*/ -private[parquet] object ParquetFilters { +private[parquet] class ParquetFilters( + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int) { + + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + length: Int, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) } - private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) // Binary.fromString and 
Binary.fromByteArray don't accept null values - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } - private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => (n: 
String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } - private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.lt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.lt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case 
ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.ltEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.ltEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: 
String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.gt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), - 
Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.gtEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gtEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } /** * Returns a map from name of the column to the data type, if predicate push down applies. */ - private def getFieldMap(dataType: DataType): Map[String, DataType] = dataType match { - case StructType(fields) => + private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match { + case m: MessageType => // Here we don't flatten the fields in the nested schema but just look up through // root fields. Currently, accessing to nested fields does not push down filters // and it does not support to create filters for them. - fields.map(f => f.name -> f.dataType).toMap - case _ => Map.empty[String, DataType] + m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => + f.getName -> ParquetSchemaType( + f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata) + }.toMap + case _ => Map.empty[String, ParquetSchemaType] } /** * Converts data sources filters to Parquet filter predicates. */ - def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { + def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToType = getFieldMap(schema) + // For decimal types, the scale of the pushed filter value must match the scale in the file. + // If the scales do not match, pushing the filter down would cause data corruption. + def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalMeta.getScale + case _ => false + } + + // The Parquet type in the given file must match the value's type in the pushed filter + // in order to push the filter down to Parquet. 
+ def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToType(name) match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => value.isInstanceOf[Date] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] + case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case _ => false + }) + } + // Parquet does not allow dots in the column name because dots are used as a column path // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates // with missing columns. The incorrect results could be got from Parquet when we push down // filters for the column having dots in the names. Thus, we do not push down such filters. // See SPARK-20364. - def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") + def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) + } // NOTE: // @@ -224,29 +427,29 @@ private[parquet] object ParquetFilters { // Probably I missed something and obviously this should be changed. 
predicate match { - case sources.IsNull(name) if canMakeFilterOn(name) => + case sources.IsNull(name) if canMakeFilterOn(name, null) => makeEq.lift(nameToType(name)).map(_(name, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name) => + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => makeNotEq.lift(nameToType(name)).map(_(name, null)) - case sources.EqualTo(name, value) if canMakeFilterOn(name) => + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name) => + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToType(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name) => + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name) => + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToType(name)).map(_(name, value)) - case sources.LessThan(name, value) if canMakeFilterOn(name) => + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => makeLt.lift(nameToType(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name) => + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeLtEq.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThan(name, value) if canMakeFilterOn(name) => + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => makeGt.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name) => + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeGtEq.lift(nameToType(name)).map(_(name, value)) case sources.And(lhs, rhs) => @@ -271,6 +474,44 @@ private[parquet] object ParquetFilters { case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) + case sources.In(name, values) if canMakeFilterOn(name, values.head) + && values.distinct.length <= pushDownInFilterThreshold => + values.distinct.flatMap { v => + makeEq.lift(nameToType(name)).map(_(name, v)) + }.reduceLeftOption(FilterApi.or) + + case sources.StringStartsWith(name, prefix) + if pushDownStartWith && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined(binaryColumn(name), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + 
UTF8String.fromBytes(value.getBytes).startsWith( + UTF8String.fromBytes(strToBinary.getBytes)) + } + } + ) + } + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index f36a89a4c3c5f..9cfc30725f03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -81,7 +81,10 @@ object ParquetOptions { "uncompressed" -> CompressionCodecName.UNCOMPRESSED, "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, - "lzo" -> CompressionCodecName.LZO) + "lzo" -> CompressionCodecName.LZO, + "lz4" -> CompressionCodecName.LZ4, + "brotli" -> CompressionCodecName.BROTLI, + "zstd" -> CompressionCodecName.ZSTD) def getParquetCompressionCodecName(name: String): String = { shortParquetCompressionCodecNames(name).name() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 40ce5d5e0564e..3319e73f2b313 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.util.{Map => JMap, TimeZone} +import java.util.{Locale, Map => JMap, TimeZone} import scala.collection.JavaConverters._ @@ -30,6 +30,7 @@ import org.apache.parquet.schema.Type.Repetition import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -71,8 +72,10 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone]) StructType.fromString(schemaString) } - val parquetRequestedSchema = - ParquetReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) + val caseSensitive = context.getConfiguration.getBoolean(SQLConf.CASE_SENSITIVE.key, + SQLConf.CASE_SENSITIVE.defaultValue.get) + val parquetRequestedSchema = ParquetReadSupport.clipParquetSchema( + context.getFileSchema, catalystRequestedSchema, caseSensitive) new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) } @@ -117,8 +120,12 @@ private[parquet] object ParquetReadSupport { * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist * in `catalystSchema`, and adding those only exist in `catalystSchema`. 
*/ - def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = { - val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema) + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + val clippedParquetFields = clipParquetGroupFields( + parquetSchema.asGroupType(), catalystSchema, caseSensitive) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -129,20 +136,21 @@ private[parquet] object ParquetReadSupport { } } - private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { + private def clipParquetType( + parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType) + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t) + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able @@ -168,14 +176,15 @@ private[parquet] object ParquetReadSupport { * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a * [[StructType]]. */ - private def clipParquetListType(parquetList: GroupType, elementType: DataType): Type = { + private def clipParquetListType( + parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. assert(!isPrimitiveCatalystType(elementType)) // Unannotated repeated group should be interpreted as required list of required element, so // list element type is just the group itself. Clip it. if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType) + clipParquetType(parquetList, elementType, caseSensitive) } else { assert( parquetList.getOriginalType == OriginalType.LIST, @@ -207,7 +216,7 @@ private[parquet] object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(OriginalType.LIST) - .addField(clipParquetType(repeatedGroup, elementType)) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) .named(parquetList.getName) } else { // Otherwise, the repeated field's type is the element type with the repeated field's @@ -218,7 +227,7 @@ private[parquet] object ParquetReadSupport { .addField( Types .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType)) + .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) .named(repeatedGroup.getName)) .named(parquetList.getName) } @@ -231,7 +240,10 @@ private[parquet] object ParquetReadSupport { * a [[StructType]]. 
*/ private def clipParquetMapType( - parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { + parquetMap: GroupType, + keyType: DataType, + valueType: DataType, + caseSensitive: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -243,8 +255,8 @@ private[parquet] object ParquetReadSupport { Types .repeatedGroup() .as(repeatedGroup.getOriginalType) - .addField(clipParquetType(parquetKeyType, keyType)) - .addField(clipParquetType(parquetValueType, valueType)) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) .named(repeatedGroup.getName) Types @@ -262,8 +274,9 @@ private[parquet] object ParquetReadSupport { * [[MessageType]]. Because it's legal to construct an empty requested schema for column * pruning. */ - private def clipParquetGroup(parquetRecord: GroupType, structType: StructType): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType) + private def clipParquetGroup( + parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getOriginalType) @@ -277,14 +290,35 @@ private[parquet] object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. */ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType): Seq[Type] = { - val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - structType.map { f => - parquetFieldMap - .get(f.name) - .map(clipParquetType(_, f.dataType)) - .getOrElse(toParquet.convertField(f)) + if (caseSensitive) { + val caseSensitiveParquetFieldMap = + parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + structType.map { f => + caseSensitiveParquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType, caseSensitive)) + .getOrElse(toParquet.convertField(f)) + } + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + structType.map { f => + caseInsensitiveParquetFieldMap + .get(f.name.toLowerCase(Locale.ROOT)) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. 
more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw new RuntimeException(s"""Found duplicate field(s) "${f.name}": """ + + s"$parquetTypesString in case-insensitive mode") + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + } + }.getOrElse(toParquet.convertField(f)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c61be077d309f..8ce8a86d2f026 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -26,7 +26,6 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.maxPrecisionForBytes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -171,7 +170,7 @@ class ParquetToSparkSchemaConverter( case FIXED_LEN_BYTE_ARRAY => originalType match { - case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength)) + case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) case INTERVAL => typeNotImplemented() case _ => illegalType() } @@ -411,7 +410,7 @@ class SparkToParquetSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(ParquetSchemaConverter.minBytesForPrecision(precision)) + .length(Decimal.minBytesForPrecision(precision)) .named(field.name) // ======================== @@ -445,7 +444,7 @@ class SparkToParquetSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(ParquetSchemaConverter.minBytesForPrecision(precision)) + .length(Decimal.minBytesForPrecision(precision)) .named(field.name) // =================================== @@ -555,7 +554,7 @@ class SparkToParquetSchemaConverter( convertField(field.copy(dataType = udt.sqlType)) case _ => - throw new AnalysisException(s"Unsupported data type $field.dataType") + throw new AnalysisException(s"Unsupported data type ${field.dataType.catalogString}") } } } @@ -584,23 +583,4 @@ private[sql] object ParquetSchemaConverter { throw new AnalysisException(message) } } - - private def computeMinBytesForPrecision(precision : Int) : Int = { - var numBytes = 1 - while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { - numBytes += 1 - } - numBytes - } - - // Returns the minimum number of bytes needed to store a decimal with a given `precision`. 
- val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - - // Max precision of a decimal value stored in `numBytes` bytes - def maxPrecisionForBytes(numBytes: Int): Int = { - Math.round( // convert double to long - Math.floor(Math.log10( // number of base-10 digits - Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes - .asInstanceOf[Int] - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala new file mode 100644 index 0000000000000..6a46b5f8edc54 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectionOverSchema, SelectedField} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} + +/** + * Prunes unnecessary Parquet columns given a [[PhysicalOperation]] over a + * [[ParquetRelation]]. By "Parquet column", we mean a column as defined in the + * Parquet format. In Spark SQL, a root-level Parquet column corresponds to a + * SQL column, and a nested Parquet column corresponds to a [[StructField]]. + */ +private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = + if (SQLConf.get.nestedSchemaPruningEnabled) { + apply0(plan) + } else { + plan + } + + private def apply0(plan: LogicalPlan): LogicalPlan = + plan transformDown { + case op @ PhysicalOperation(projects, filters, + l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) + if canPruneRelation(hadoopFsRelation) => + val (normalizedProjects, normalizedFilters) = + normalizeAttributeRefNames(l, projects, filters) + val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters) + + // If requestedRootFields includes a nested field, continue. 
Otherwise, + // return op + if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) { + val dataSchema = hadoopFsRelation.dataSchema + val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields) + + // If the data schema is different from the pruned data schema, continue. Otherwise, + // return op. We effect this comparison by counting the number of "leaf" fields in + // each schemata, assuming the fields in prunedDataSchema are a subset of the fields + // in dataSchema. + if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { + val prunedParquetRelation = + hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession) + + val prunedRelation = buildPrunedRelation(l, prunedParquetRelation) + val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) + + buildNewProjection(normalizedProjects, normalizedFilters, prunedRelation, + projectionOverSchema) + } else { + op + } + } else { + op + } + } + + /** + * Checks to see if the given relation is Parquet and can be pruned. + */ + private def canPruneRelation(fsRelation: HadoopFsRelation) = + fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] + + /** + * Normalizes the names of the attribute references in the given projects and filters to reflect + * the names in the given logical relation. This makes it possible to compare attributes and + * fields by name. Returns a tuple with the normalized projects and filters, respectively. + */ + private def normalizeAttributeRefNames( + logicalRelation: LogicalRelation, + projects: Seq[NamedExpression], + filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = { + val normalizedAttNameMap = logicalRelation.output.map(att => (att.exprId, att.name)).toMap + val normalizedProjects = projects.map(_.transform { + case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) => + att.withName(normalizedAttNameMap(att.exprId)) + }).map { case expr: NamedExpression => expr } + val normalizedFilters = filters.map(_.transform { + case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) => + att.withName(normalizedAttNameMap(att.exprId)) + }) + (normalizedProjects, normalizedFilters) + } + + /** + * Returns the set of fields from the Parquet file that the query plan needs. + */ + private def identifyRootFields(projects: Seq[NamedExpression], filters: Seq[Expression]) = { + val projectionRootFields = projects.flatMap(getRootFields) + val filterRootFields = filters.flatMap(getRootFields) + + (projectionRootFields ++ filterRootFields).distinct + } + + /** + * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation. 
+ */ + private def buildNewProjection( + projects: Seq[NamedExpression], filters: Seq[Expression], prunedRelation: LogicalRelation, + projectionOverSchema: ProjectionOverSchema) = { + // Construct a new target for our projection by rewriting and + // including the original filters where available + val projectionChild = + if (filters.nonEmpty) { + val projectedFilters = filters.map(_.transformDown { + case projectionOverSchema(expr) => expr + }) + val newFilterCondition = projectedFilters.reduce(And) + Filter(newFilterCondition, prunedRelation) + } else { + prunedRelation + } + + // Construct the new projections of our Project by + // rewriting the original projections + val newProjects = projects.map(_.transformDown { + case projectionOverSchema(expr) => expr + }).map { case expr: NamedExpression => expr } + + if (log.isDebugEnabled) { + logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}") + } + + Project(newProjects, projectionChild) + } + + /** + * Filters the schema from the given file by the requested fields. + * Schema field ordering from the file is preserved. + */ + private def pruneDataSchema( + fileDataSchema: StructType, + requestedRootFields: Seq[RootField]) = { + // Merge the requested root fields into a single schema. Note the ordering of the fields + // in the resulting schema may differ from their ordering in the logical relation's + // original schema + val mergedSchema = requestedRootFields + .map { case RootField(field, _) => StructType(Array(field)) } + .reduceLeft(_ merge _) + val dataSchemaFieldNames = fileDataSchema.fieldNames.toSet + val mergedDataSchema = + StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name))) + // Sort the fields of mergedDataSchema according to their order in dataSchema, + // recursively. This makes mergedDataSchema a pruned schema of dataSchema + sortLeftFieldsByRight(mergedDataSchema, fileDataSchema).asInstanceOf[StructType] + } + + /** + * Builds a pruned logical relation from the output of the output relation and the schema of the + * pruned base relation. + */ + private def buildPrunedRelation( + outputRelation: LogicalRelation, + prunedBaseRelation: HadoopFsRelation) = { + // We need to replace the expression ids of the pruned relation output attributes + // with the expression ids of the original relation output attributes so that + // references to the original relation's output are not broken + val outputIdMap = outputRelation.output.map(att => (att.name, att.exprId)).toMap + val prunedRelationOutput = + prunedBaseRelation + .schema + .toAttributes + .map { + case att if outputIdMap.contains(att.name) => + att.withExprId(outputIdMap(att.name)) + case att => att + } + outputRelation.copy(relation = prunedBaseRelation, output = prunedRelationOutput) + } + + /** + * Gets the root (aka top-level, no-parent) [[StructField]]s for the given [[Expression]]. + * When expr is an [[Attribute]], construct a field around it and indicate that that + * field was derived from an attribute. + */ + private def getRootFields(expr: Expression): Seq[RootField] = { + expr match { + case att: Attribute => + RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil + case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil + case _ => + expr.children.flatMap(getRootFields) + } + } + + /** + * Counts the "leaf" fields of the given dataType. Informally, this is the + * number of fields of non-complex data type in the tree representation of + * [[DataType]]. 
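+ * For example, a field of type struct&lt;a: int, b: struct&lt;c: string, d: double&gt;&gt; contributes three leaves (a, c and d), whereas an atomic field contributes exactly one.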
+ */ + private def countLeaves(dataType: DataType): Int = { + dataType match { + case array: ArrayType => countLeaves(array.elementType) + case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType) + case struct: StructType => + struct.map(field => countLeaves(field.dataType)).sum + case _ => 1 + } + } + + /** + * Sorts the fields and descendant fields of structs in left according to their order in + * right. This function assumes that the fields of left are a subset of the fields of + * right, recursively. That is, left is a "subschema" of right, ignoring order of + * fields. + */ + private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) => + ArrayType( + sortLeftFieldsByRight(leftElementType, rightElementType), + containsNull) + case (MapType(leftKeyType, leftValueType, containsNull), + MapType(rightKeyType, rightValueType, _)) => + MapType( + sortLeftFieldsByRight(leftKeyType, rightKeyType), + sortLeftFieldsByRight(leftValueType, rightValueType), + containsNull) + case (leftStruct: StructType, rightStruct: StructType) => + val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains) + val sortedLeftFields = filteredRightFieldNames.map { fieldName => + val leftFieldType = leftStruct(fieldName).dataType + val rightFieldType = rightStruct(fieldName).dataType + val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType) + StructField(fieldName, sortedLeftFieldType) + } + StructType(sortedLeftFields) + case _ => left + } + + /** + * A "root" schema field (aka top-level, no-parent) and whether it was derived from + * an attribute or had a proper child. + */ + private case class RootField(field: StructField, derivedFromAtt: Boolean) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index af4e1433c876f..b40b8c2e61f33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -33,7 +33,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.minBytesForPrecision import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -73,7 +72,8 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit private val timestampBuffer = new Array[Byte](12) // Reusable byte array used to write decimal values - private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) + private val decimalBuffer = + new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION)) override def init(configuration: Configuration): WriteContext = { val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) @@ -212,7 +212,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit precision <= DecimalType.MAX_PRECISION, s"Decimal precision $precision exceeds max precision ${DecimalType.MAX_PRECISION}") - val numBytes = 
minBytesForPrecision(precision) + val numBytes = Decimal.minBytesForPrecision(precision) val int32Writer = (row: SpecializedGetters, ordinal: Int) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0dea767840ed3..949aa665527ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -39,7 +39,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( @@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { case _: ClassNotFoundException => u case e: Exception => // the provider is valid, but failed to create a logical plan - u.failAnalysis(e.getMessage) + u.failAnalysis(e.getMessage, e) } } } @@ -73,7 +73,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // catalog is a def and not a val/lazy val as the latter would introduce a circular reference private def catalog = sparkSession.sessionState.catalog - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // When we CREATE TABLE without specifying the table schema, we should fail the query if // bucketing information is specified, as we can't infer bucketing from data files currently. 
// Since the runtime inferred partition columns could be different from what user specified, @@ -281,7 +281,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi schema.filter(f => normalizedPartitionCols.contains(f.name)).map(_.dataType).foreach { case _: AtomicType => // OK - case other => failAnalysis(s"Cannot use ${other.simpleString} for partition column") + case other => failAnalysis(s"Cannot use ${other.catalogString} for partition column") } normalizedPartitionCols @@ -307,7 +307,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi normalizedBucketSpec.sortColumnNames.map(schema(_)).map(_.dataType).foreach { case dt if RowOrdering.isOrderable(dt) => // OK - case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column") + case other => failAnalysis(s"Cannot use ${other.catalogString} for sorting column") } Some(normalizedBucketSpec) @@ -365,7 +365,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { case relation: HiveTableRelation => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index e93908da43535..268297148b522 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{DataType, StringType, StructType} import org.apache.spark.util.SerializableConfiguration /** @@ -47,11 +47,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { throw new AnalysisException( s"Text data source supports only a single column, and you have ${schema.size} columns.") } - val tpe = schema(0).dataType - if (tpe != StringType) { - throw new AnalysisException( - s"Text data source supports only a string column, but you have ${tpe.simpleString}.") - } } override def isSplitable( @@ -125,7 +120,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } else { new HadoopFileWholeTextReader(file, confValue) } - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close())) if (requiredSchema.isEmpty) { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) @@ -141,6 +136,9 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = + dataType == StringType } class TextOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 5c1a35434f7b5..e4e201995faa2 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.text -import java.nio.charset.StandardCharsets +import java.nio.charset.{Charset, StandardCharsets} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} @@ -41,13 +41,18 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean - private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => - require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") - sep + val encoding: Option[String] = parameters.get(ENCODING) + + val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + + lineSep } + // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8)) + } val lineSeparatorInWrite: Array[Byte] = lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } @@ -55,5 +60,6 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val ENCODING = "encoding" val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 017a6737161a6..33079d5912506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -30,8 +30,8 @@ class DataSourcePartitioning( override val numPartitions: Int = partitioning.numPartitions() - override def satisfies(required: physical.Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: physical.Distribution): Boolean = { + super.satisfies0(required) || { required match { case d: physical.ClusteredDistribution if isCandidate(d.clustering) => val attrs = d.clustering.map(_.asInstanceOf[Attribute]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index f85971be394b1..f62f7349d1da7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,31 +17,44 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - -import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} +import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.catalyst.InternalRow +import 
org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory} -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) +class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) extends Partition with Serializable -class DataSourceRDD[T: ClassTag]( +// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for +// columnar scan. +class DataSourceRDD( sc: SparkContext, - @transient private val readerFactories: Seq[DataReaderFactory[T]]) - extends RDD[T](sc, Nil) { + @transient private val inputPartitions: Seq[InputPartition], + partitionReaderFactory: PartitionReaderFactory, + columnarReads: Boolean) + extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) + inputPartitions.zipWithIndex.map { + case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() - context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[T] { + private def castPartition(split: Partition): DataSourceRDDPartition = split match { + case p: DataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a DataSourceRDDPartition: $split") + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val inputPartition = castPartition(split).inputPartition + val reader: PartitionReader[_] = if (columnarReads) { + partitionReaderFactory.createColumnarReader(inputPartition) + } else { + partitionReaderFactory.createReader(inputPartition) + } + + context.addTaskCompletionListener[Unit](_ => reader.close()) + val iter = new Iterator[Any] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +64,7 @@ class DataSourceRDD[T: ClassTag]( valuePrepared } - override def next(): T = { + override def next(): Any = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -59,10 +72,11 @@ class DataSourceRDD[T: ClassTag]( reader.get() } } - new InterruptibleIterator(context, iter) + // TODO: SPARK-25083 remove the type erasure hack in data source scan + new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() + castPartition(split).inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 2b282ffae2390..f7e29593a6353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,89 +17,59 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + import scala.collection.JavaConverters._ -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.{AnalysisException, SaveMode} +import 
org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, BatchWriteSupportProvider, DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.reader.{BatchReadSupport, ReadSupport, ScanConfigBuilder, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport import org.apache.spark.sql.types.StructType +/** + * A logical plan representing a data source v2 scan. + * + * @param source An instance of a [[DataSourceV2]] implementation. + * @param options The options for this scan. Used to create fresh [[BatchWriteSupport]]. + * @param userSpecifiedSchema The user-specified schema for this scan. + */ case class DataSourceV2Relation( source: DataSourceV2, + readSupport: BatchReadSupport, + output: Seq[AttributeReference], options: Map[String, String], - projection: Seq[AttributeReference], - filters: Option[Seq[Expression]] = None, + tableIdent: Option[TableIdentifier] = None, userSpecifiedSchema: Option[StructType] = None) - extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + extends LeafNode with MultiInstanceRelation with NamedRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = "RelationV2 " + metadataString - - override lazy val schema: StructType = reader.readSchema() - - override lazy val output: Seq[AttributeReference] = { - // use the projection attributes to avoid assigning new ids. fields that are not projected - // will be assigned new ids, which is okay because they are not projected. 
- val attrMap = projection.map(a => a.name -> a).toMap - schema.map(f => attrMap.getOrElse(f.name, - AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) - } - - private lazy val v2Options: DataSourceOptions = makeV2Options(options) - - lazy val ( - reader: DataSourceReader, - unsupportedFilters: Seq[Expression], - pushedFilters: Seq[Expression]) = { - val newReader = userSpecifiedSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) - } - - DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - - val (remainingFilters, pushedFilters) = filters match { - case Some(filterSeq) => - DataSourceV2Relation.pushFilters(newReader, filterSeq) - case _ => - (Nil, Nil) - } - - (newReader, remainingFilters, pushedFilters) + override def name: String = { + tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown") } - override def doCanonicalize(): LogicalPlan = { - val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + override def pushedFilters: Seq[Expression] = Seq.empty - // override output with canonicalized output to avoid attempting to configure a reader - val canonicalOutput: Seq[AttributeReference] = this.output - .map(a => QueryPlan.normalizeExprId(a, projection)) + override def simpleString: String = "RelationV2 " + metadataString - new DataSourceV2Relation(c.source, c.options, c.projection) { - override lazy val output: Seq[AttributeReference] = canonicalOutput - } - } + def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = readSupport match { case r: SupportsReportStatistics => - Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + val statistics = r.estimateStatistics(readSupport.newScanConfigBuilder().build()) + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } override def newInstance(): DataSourceV2Relation = { - // projection is used to maintain id assignment. - // if projection is not set, use output so the copy is not equal to the original - copy(projection = projection.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -114,19 +84,23 @@ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], source: DataSourceV2, options: Map[String, String], - reader: DataSourceReader) + readSupport: ReadSupport, + scanConfigBuilder: ScanConfigBuilder) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { override def isStreaming: Boolean = true override def simpleString: String = "Streaming RelationV2 " + metadataString + override def pushedFilters: Seq[Expression] = Nil + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
override def equals(other: Any): Boolean = other match { case other: StreamingDataSourceV2Relation => - output == other.output && reader.getClass == other.reader.getClass && options == other.options + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options case _ => false } @@ -134,9 +108,10 @@ case class StreamingDataSourceV2Relation( Seq(output, source, options).hashCode() } - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = readSupport match { case r: SupportsReportStatistics => - Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + val statistics = r.estimateStatistics(scanConfigBuilder.build()) + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -144,28 +119,21 @@ case class StreamingDataSourceV2Relation( object DataSourceV2Relation { private implicit class SourceHelpers(source: DataSourceV2) { - def asReadSupport: ReadSupport = { + def asReadSupportProvider: BatchReadSupportProvider = { source match { - case support: ReadSupport => - support - case _: ReadSupportWithSchema => - // this method is only called if there is no user-supplied schema. if there is no - // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. - throw new AnalysisException(s"Data source requires a user-supplied schema: $name") + case provider: BatchReadSupportProvider => + provider case _ => throw new AnalysisException(s"Data source is not readable: $name") } } - def asReadSupportWithSchema: ReadSupportWithSchema = { + def asWriteSupportProvider: BatchWriteSupportProvider = { source match { - case support: ReadSupportWithSchema => - support - case _: ReadSupport => - throw new AnalysisException( - s"Data source does not support user-supplied schema: $name") + case provider: BatchWriteSupportProvider => + provider case _ => - throw new AnalysisException(s"Data source is not readable: $name") + throw new AnalysisException(s"Data source is not writable: $name") } } @@ -177,73 +145,45 @@ object DataSourceV2Relation { source.getClass.getSimpleName } } - } - private def makeV2Options(options: Map[String, String]): DataSourceOptions = { - new DataSourceOptions(options.asJava) - } + def createReadSupport( + options: Map[String, String], + userSpecifiedSchema: Option[StructType]): BatchReadSupport = { + val v2Options = new DataSourceOptions(options.asJava) + userSpecifiedSchema match { + case Some(s) => + asReadSupportProvider.createBatchReadSupport(s, v2Options) + case _ => + asReadSupportProvider.createBatchReadSupport(v2Options) + } + } - private def schema( - source: DataSourceV2, - v2Options: DataSourceOptions, - userSchema: Option[StructType]): StructType = { - val reader = userSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) + def createWriteSupport( + options: Map[String, String], + schema: StructType): BatchWriteSupport = { + asWriteSupportProvider.createBatchWriteSupport( + UUID.randomUUID().toString, + schema, + SaveMode.Append, + new DataSourceOptions(options.asJava)).get } - reader.readSchema() } def create( source: DataSourceV2, options: Map[String, String], - filters: Option[Seq[Expression]] = None, + tableIdent: Option[TableIdentifier] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val projection = 
schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) + val readSupport = source.createReadSupport(options, userSpecifiedSchema) + val output = readSupport.fullSchema().toAttributes + val ident = tableIdent.orElse(tableFromOptions(options)) + DataSourceV2Relation( + source, readSupport, output, options, ident, userSpecifiedSchema) } - private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { - reader match { - case projectionSupport: SupportsPushDownRequiredColumns => - projectionSupport.pruneColumns(struct) - case _ => - } - } - - private def pushFilters( - reader: DataSourceReader, - filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { - case catalystFilterSupport: SupportsPushDownCatalystFilters => - ( - catalystFilterSupport.pushCatalystFilters(filters.toArray), - catalystFilterSupport.pushedCatalystFilters() - ) - - case filterSupport: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet - val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => - unhandledFilters.contains(f) - } - - (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) - - case _ => (filters, Nil) - } + private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { + options + .get(DataSourceOptions.TABLE_KEY) + .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3a5e7bf89e142..04a97735d024d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -17,12 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.SinglePartition @@ -30,9 +26,7 @@ import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeSta import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import 
org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} /** * Physical plan node for scanning data from a data source. @@ -41,7 +35,9 @@ case class DataSourceV2ScanExec( output: Seq[AttributeReference], @transient source: DataSourceV2, @transient options: Map[String, String], - @transient reader: DataSourceReader) + @transient pushedFilters: Seq[Expression], + @transient readSupport: ReadSupport, + @transient scanConfig: ScanConfig) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { override def simpleString: String = "ScanV2 " + metadataString @@ -49,7 +45,8 @@ case class DataSourceV2ScanExec( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { case other: DataSourceV2ScanExec => - output == other.output && reader.getClass == other.reader.getClass && options == other.options + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options case _ => false } @@ -57,61 +54,58 @@ case class DataSourceV2ScanExec( Seq(output, source, options).hashCode() } - override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => - SinglePartition - - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => - SinglePartition - - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + override def outputPartitioning: physical.Partitioning = readSupport match { + case _ if partitions.length == 1 => SinglePartition case s: SupportsReportPartitioning => new DataSourcePartitioning( - s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name))) case _ => super.outputPartitioning } - private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala - case _ => - reader.createDataReaderFactories().asScala.map { - new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] - } + private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) + + private lazy val readerFactory = readSupport match { + case r: BatchReadSupport => r.createReaderFactory(scanConfig) + case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) + case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) + case _ => throw new IllegalStateException("unknown read support: " + readSupport) } - private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - assert(!reader.isInstanceOf[ContinuousReader], - "continuous stream reader does not support columnar read yet.") - r.createBatchDataReaderFactories().asScala + // TODO: clean this up when we have dedicated scan plan for continuous streaming. 
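+ // Columnar reads are all-or-nothing across input partitions: the require below rejects a mix of row-based and columnar partitions, and batch support is reported only when the partitions support columnar reads.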
+ override val supportsBatch: Boolean = { + require(partitions.forall(readerFactory.supportColumnarReads) || + !partitions.exists(readerFactory.supportColumnarReads), + "Cannot mix row-based and columnar input partitions.") + + partitions.exists(readerFactory.supportColumnarReads) } - private lazy val inputRDD: RDD[InternalRow] = reader match { - case _: ContinuousReader => + private lazy val inputRDD: RDD[InternalRow] = readSupport match { + case _: ContinuousReadSupport => + assert(!supportsBatch, + "continuous stream reader does not support columnar read yet.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size)) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) - .asInstanceOf[RDD[InternalRow]] - - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + .askSync[Unit](SetReaderPartitions(partitions.size)) + new ContinuousDataSourceRDD( + sparkContext, + sqlContext.conf.continuousStreamingExecutorQueueSize, + sqlContext.conf.continuousStreamingExecutorPollIntervalMs, + partitions, + schema, + readerFactory.asInstanceOf[ContinuousPartitionReaderFactory]) case _ => - new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD( + sparkContext, partitions, readerFactory.asInstanceOf[PartitionReaderFactory], supportsBatch) } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - override val supportsBatch: Boolean = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => true - case _ => false - } - override protected def needsUnsafeRowConversion: Boolean = false override protected def doExecute(): RDD[InternalRow] = { @@ -126,24 +120,3 @@ case class DataSourceV2ScanExec( } } } - -class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) - extends DataReaderFactory[UnsafeRow] { - - override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations - - override def createDataReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader( - rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) - } -} - -class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) - extends DataReader[UnsafeRow] { - - override def next: Boolean = rowReader.next - - override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow] - - override def close(): Unit = rowReader.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1ac9572de6412..9a3109e7c199e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,21 +17,148 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import scala.collection.mutable + +import org.apache.spark.sql.{sources, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} +import 
org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport object DataSourceV2Strategy extends Strategy { + + /** + * Pushes down filters to the data source reader + * + * @return pushed filter and post-scan filters. + */ + private def pushFilters( + configBuilder: ScanConfigBuilder, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + configBuilder match { + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray) + .map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + (pushedFilters, untranslatableExprs ++ postScanFilters) + + case _ => (Nil, filters) + } + } + + /** + * Applies column pruning to the data source, w.r.t. the references of the given expressions. + * + * @return the created `ScanConfig`(since column pruning is the last step of operator pushdown), + * and new output attributes after column pruning. + */ + // TODO: nested column pruning. + private def pruneColumns( + configBuilder: ScanConfigBuilder, + relation: DataSourceV2Relation, + exprs: Seq[Expression]): (ScanConfig, Seq[AttributeReference]) = { + configBuilder match { + case r: SupportsPushDownRequiredColumns => + val requiredColumns = AttributeSet(exprs.flatMap(_.references)) + val neededOutput = relation.output.filter(requiredColumns.contains) + if (neededOutput != relation.output) { + r.pruneColumns(neededOutput.toStructType) + val config = r.build() + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + config -> config.readSchema().toAttributes.map { + // We have to keep the attribute id during transformation. 
+ a => a.withExprId(nameToAttr(a.name).exprId) + } + } else { + r.build() -> relation.output + } + + case _ => configBuilder.build() -> relation.output + } + } + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val configBuilder = relation.readSupport.newScanConfigBuilder() + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. + val (pushedFilters, postScanFilters) = pushFilters(configBuilder, filters) + val (config, output) = pruneColumns(configBuilder, relation, project ++ postScanFilters) + logInfo( + s""" + |Pushing operators to ${relation.source.getClass} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val scan = DataSourceV2ScanExec( + output, + relation.source, + relation.options, + pushedFilters, + relation.readSupport, + config) + + val filterCondition = postScanFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) + + // always add the projection, which will produce unsafe rows required by some operators + ProjectExec(project, withFilter) :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + // TODO: support operator pushdown for streaming data sources. + val scanConfig = r.scanConfigBuilder.build() + // ensure there is a projection, which will produce unsafe rows required by some operators + ProjectExec(r.output, + DataSourceV2ScanExec( + r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case AppendData(r: DataSourceV2Relation, query, _) => + WriteToDataSourceV2Exec(r.newWriteSupport(), planLater(query)) :: Nil + + case WriteToContinuousDataSource(writer, query) => + WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + + case Repartition(1, false, child) => + val isContinuous = child.find { + case s: StreamingDataSourceV2Relation => s.readSupport.isInstanceOf[ContinuousReadSupport] + case _ => false + }.isDefined + + if (isContinuous) { + ContinuousCoalesceExec(1, planLater(child)) :: Nil + } else { + Nil + } + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index aed55a429bfd7..97e6c6d702acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -19,11 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 -import 
org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.util.Utils /** @@ -49,27 +47,22 @@ trait DataSourceV2StringFormat { def options: Map[String, String] /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. + * The filters which have been pushed to the data source. */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } + def pushedFilters: Seq[Expression] private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() - case _ => source.getClass.getSimpleName.stripSuffix("$") + // source.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + case _ => Utils.getSimpleName(source.getClass) } def metadataString: String = { val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) { - entries += "Filters" -> filters.mkString("[", ", ", "]") + if (pushedFilters.nonEmpty) { + entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") } // TODO: we should only display some standard options like path, table, etc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 5267f5f1580c3..e9cc3991155c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,6 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} private[sql] object DataSourceV2Utils extends Logging { @@ -55,4 +56,12 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } + + def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { + val name = ds match { + case register: DataSourceRegister => register.shortName() + case _ => ds.getClass.getName + } + throw new UnsupportedOperationException(name + " source does not support user-specified schema") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala deleted file mode 100644 index f23d228567241..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply( - plan: LogicalPlan): LogicalPlan = plan transformUp { - // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => - // merge the filters - val filters = relation.filters match { - case Some(existing) => - existing ++ newFilters - case _ => - newFilters - } - - val projectAttrs = project.map(_.toAttribute) - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(projectAttrs) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - projectAttrs - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - - val newRelation = relation.copy( - projection = projection.asInstanceOf[Seq[AttributeReference]], - filters = Some(filters)) - - // Add a Filter for any filters that could not be pushed - val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) - val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) - - // Add a Project to ensure the output matches the required projection - if (newRelation.output != projectAttrs) { - Project(project, filtered) - } else { - filtered - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala deleted file mode 100644 index e80b44c1cdc66..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import scala.util.control.NonFatal - -import org.apache.spark.{SparkEnv, SparkException, TaskContext} -import org.apache.spark.executor.CommitDeniedException -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} -import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils - -/** - * The logical plan for writing data into data source v2. - */ -case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { - override def children: Seq[LogicalPlan] = Seq(query) - override def output: Seq[Attribute] = Nil -} - -/** - * The physical plan for writing data into data source v2. - */ -case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { - override def children: Seq[SparkPlan] = Seq(query) - override def output: Seq[Attribute] = Nil - - override protected def doExecute(): RDD[InternalRow] = { - val writeTask = writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) - } - - val useCommitCoordinator = writer.useCommitCoordinator - val rdd = query.execute() - val messages = new Array[WriterCommitMessage](rdd.partitions.length) - - logInfo(s"Start processing data source writer: $writer. " + - s"The input RDD has ${messages.length} partitions.") - - try { - val runTask = writer match { - // This case means that we're doing continuous processing. In microbatch streaming, the - // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. 
- case w: StreamWriter => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) - .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) - - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.runContinuous(writeTask, context, iter) - case _ => - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) - } - - sparkContext.runJob( - rdd, - runTask, - rdd.partitions.indices, - (index, message: WriterCommitMessage) => { - messages(index) = message - writer.onDataWriterCommit(message) - } - ) - - if (!writer.isInstanceOf[StreamWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } - } catch { - case _: InterruptedException if writer.isInstanceOf[StreamWriter] => - // Interruption is how continuous queries are ended, so accept and ignore the exception. - case cause: Throwable => - logError(s"Data source writer $writer is aborting.") - try { - writer.abort(messages) - } catch { - case t: Throwable => - logError(s"Data source writer $writer failed to abort.") - cause.addSuppressed(t) - throw new SparkException("Writing job failed.", cause) - } - logError(s"Data source writer $writer aborted.") - cause match { - // Do not wrap interruption exceptions that will be handled by streaming specially. - case _ if StreamExecution.isInterruptionException(cause) => throw cause - // Only wrap non fatal exceptions. - case NonFatal(e) => throw new SparkException("Writing job aborted.", e) - case _ => throw cause - } - } - - sparkContext.emptyRDD - } -} - -object DataWritingSparkTask extends Logging { - def run( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow], - useCommitCoordinator: Boolean): WriterCommitMessage = { - val stageId = context.stageId() - val partId = context.partitionId() - val attemptId = context.attemptNumber() - val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) - - // write the data and commit this writer. 
- Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - iter.foreach(dataWriter.write) - - val msg = if (useCommitCoordinator) { - val coordinator = SparkEnv.get.outputCommitCoordinator - val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) - if (commitAuthorized) { - logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") - dataWriter.commit() - } else { - val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" - logInfo(message) - // throwing CommitDeniedException will trigger the catch block for abort - throw new CommitDeniedException(message, stageId, partId, attemptId) - } - - } else { - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - dataWriter.commit() - } - - logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") - - msg - - })(catchBlock = { - // If there is an error, abort this writer - logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") - dataWriter.abort() - logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") - }) - } - - def runContinuous( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - val currentMsg: WriterCommitMessage = null - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. 
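Both run() and runContinuous() above lean on Utils.tryWithSafeFinallyAndFailureCallbacks, which evaluates a block and, if it throws, runs a cleanup callback before rethrowing, suppressing any secondary failure from the cleanup itself. A stripped-down sketch of that pattern (not the real Utils signature, which also accepts a finally block):

object FailureCallbackSketch {
  // Run `block`; if it throws, run `onFailure` for cleanup and rethrow the original error,
  // suppressing any secondary failure coming from the cleanup itself.
  def withFailureCallback[T](block: => T)(onFailure: => Unit): T = {
    try {
      block
    } catch {
      case cause: Throwable =>
        try onFailure
        catch { case t: Throwable => cause.addSuppressed(t) }
        throw cause
    }
  }
}

// Usage, mirroring the writer task (dataWriter and rows are hypothetical):
//   FailureCallbackSketch.withFailureCallback {
//     rows.foreach(dataWriter.write); dataWriter.commit()
//   } {
//     dataWriter.abort()
//   }
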
- logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - - currentMsg - } -} - -class InternalRowDataWriterFactory( - rowWriterFactory: DataWriterFactory[Row], - schema: StructType) extends DataWriterFactory[InternalRow] { - - override def createDataWriter( - partitionId: Int, - attemptNumber: Int, - epochId: Long): DataWriter[InternalRow] = { - new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), - RowEncoder.apply(schema).resolveAndBind()) - } -} - -class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) - extends DataWriter[InternalRow] { - - override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) - - override def commit(): WriterCommitMessage = rowWriter.commit() - - override def abort(): Unit = rowWriter.abort() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala new file mode 100644 index 0000000000000..c3f7b690ef636 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.executor.CommitDeniedException +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.util.Utils + +/** + * Deprecated logical plan for writing data into data source v2. This is being replaced by more + * specific logical plans, like [[org.apache.spark.sql.catalyst.plans.logical.AppendData]]. + */ +@deprecated("Use specific logical plans like AppendData instead", "2.4.0") +case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPlan) + extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} + +/** + * The physical plan for writing data into data source v2. 
+ */ +case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) + extends SparkPlan { + + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writerFactory = writeSupport.createBatchWriterFactory() + val useCommitCoordinator = writeSupport.useCommitCoordinator + val rdd = query.execute() + val messages = new Array[WriterCommitMessage](rdd.partitions.length) + + logInfo(s"Start processing data source write support: $writeSupport. " + + s"The input RDD has ${messages.length} partitions.") + + try { + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), + rdd.partitions.indices, + (index, message: WriterCommitMessage) => { + messages(index) = message + writeSupport.onDataWriterCommit(message) + } + ) + + logInfo(s"Data source write support $writeSupport is committing.") + writeSupport.commit(messages) + logInfo(s"Data source write support $writeSupport committed.") + } catch { + case cause: Throwable => + logError(s"Data source write support $writeSupport is aborting.") + try { + writeSupport.abort(messages) + } catch { + case t: Throwable => + logError(s"Data source write support $writeSupport failed to abort.") + cause.addSuppressed(t) + throw new SparkException("Writing job failed.", cause) + } + logError(s"Data source write support $writeSupport aborted.") + cause match { + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } + } + + sparkContext.emptyRDD + } +} + +object DataWritingSparkTask extends Logging { + def run( + writerFactory: DataWriterFactory, + context: TaskContext, + iter: Iterator[InternalRow], + useCommitCoordinator: Boolean): WriterCommitMessage = { + val stageId = context.stageId() + val stageAttempt = context.stageAttemptNumber() + val partId = context.partitionId() + val taskId = context.taskAttemptId() + val attemptId = context.attemptNumber() + val dataWriter = writerFactory.createWriter(partId, taskId) + + // write the data and commit this writer. 
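When useCommitCoordinator is enabled, the task body that follows asks the driver-side OutputCommitCoordinator whether this attempt may commit, so at most one attempt per partition wins and speculative or retried attempts receive a CommitDeniedException. A toy, single-JVM illustration of that first-asker-wins rule (ToyCommitCoordinator is hypothetical, not Spark's coordinator):

import scala.collection.mutable

// Hypothetical in-memory stand-in for the driver-side OutputCommitCoordinator.
class ToyCommitCoordinator {
  private val winners = mutable.Map.empty[(Int, Int), Int] // (stage, partition) -> winning attempt

  def canCommit(stage: Int, partition: Int, attempt: Int): Boolean = synchronized {
    winners.get((stage, partition)) match {
      case Some(winner) => winner == attempt // only the first authorized attempt may commit
      case None =>
        winners((stage, partition)) = attempt
        true
    }
  }
}

object ToyCommitCoordinatorDemo extends App {
  val coord = new ToyCommitCoordinator
  println(coord.canCommit(stage = 1, partition = 0, attempt = 0)) // true: first asker wins
  println(coord.canCommit(stage = 1, partition = 0, attempt = 1)) // false: speculative copy denied
}

In Spark the coordinator additionally tracks failed attempts and stage attempts; this toy version only captures the single-committer invariant.
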
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + while (iter.hasNext) { + dataWriter.write(iter.next()) + } + + val msg = if (useCommitCoordinator) { + val coordinator = SparkEnv.get.outputCommitCoordinator + val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) + if (commitAuthorized) { + logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") + dataWriter.commit() + } else { + val message = s"Commit denied for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)" + logInfo(message) + // throwing CommitDeniedException will trigger the catch block for abort + throw new CommitDeniedException(message, stageId, partId, attemptId) + } + + } else { + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + dataWriter.commit() + } + + logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") + + msg + + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") + dataWriter.abort() + logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") + }) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index a717cbd4a7df9..366e1fe6a4aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming.continuous.WriteToContinuousDataSourceExec +import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.util.{AccumulatorV2, LongAccumulator} /** @@ -40,6 +43,16 @@ import org.apache.spark.util.{AccumulatorV2, LongAccumulator} * sql("SELECT 1").debug() * sql("SELECT 1").debugCodegen() * }}} + * + * or for streaming case (structured streaming): + * {{{ + * import org.apache.spark.sql.execution.debug._ + * val query = df.writeStream.<...>.start() + * query.debugCodegen() + * }}} + * + * Note that debug in structured streaming is not supported, because it doesn't make sense for + * streaming to execute batch once while main query is running concurrently. */ package object debug { @@ -88,14 +101,50 @@ package object debug { } } + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan into one String + * + * @param query the streaming query for codegen + * @return single String containing all WholeStageCodegen subtrees and corresponding codegen + */ + def codegenString(query: StreamingQuery): String = { + val w = asStreamExecution(query) + if (w.lastExecution != null) { + codegenString(w.lastExecution.executedPlan) + } else { + "No physical plan. Waiting for data." 
+ } + } + + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan + * + * @param query the streaming query for codegen + * @return Sequence of WholeStageCodegen subtrees and corresponding codegen + */ + def codegenStringSeq(query: StreamingQuery): Seq[(String, String)] = { + val w = asStreamExecution(query) + if (w.lastExecution != null) { + codegenStringSeq(w.lastExecution.executedPlan) + } else { + Seq.empty + } + } + + private def asStreamExecution(query: StreamingQuery): StreamExecution = query match { + case wrapper: StreamingQueryWrapper => wrapper.streamingQuery + case q: StreamExecution => q + case _ => throw new IllegalArgumentException("Parameter should be an instance of " + + "StreamExecution!") + } + /** * Augments [[Dataset]]s with debug methods. */ implicit class DebugQuery(query: Dataset[_]) extends Logging { def debug(): Unit = { - val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() - val debugPlan = plan transform { + val debugPlan = query.queryExecution.executedPlan transform { case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => visited += new TreeNodeRef(s) DebugExec(s) @@ -116,6 +165,12 @@ package object debug { } } + implicit class DebugStreamQuery(query: StreamingQuery) extends Logging { + def debugCodegen(): Unit = { + debugPrint(codegenString(query)) + } + } + case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { def output: Seq[Attribute] = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..a80673c705f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution.exchange +import java.util.concurrent.TimeoutException + import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkException} import org.apache.spark.launcher.SparkLauncher @@ -30,7 +33,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SparkFatalException, ThreadUtils} /** * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of @@ -69,7 +72,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types @@ -111,12 +114,18 @@ case class BroadcastExchangeExec( SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. 
ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. case oe: OutOfMemoryError => - throw new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + + throw new SparkFatalException( + new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + s"all worker nodes. As a workaround, you can either disable broadcast by setting " + s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " + s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value") - .initCause(oe.getCause) + .initCause(oe.getCause)) + case e if !NonFatal(e) => + throw new SparkFatalException(e) } } }(BroadcastExchangeExec.executionContext) @@ -133,7 +142,16 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + try { + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex) + throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " + + s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " + + s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1", + ex) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e3d28388c5470..d2d5011bbcb97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.exchange +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -81,7 +82,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (adaptiveExecutionEnabled && supportsCoordinator) { val coordinator = new ExchangeCoordinator( - children.length, targetPostShuffleInputSize, minNumPostShufflePartitions) children.zip(requiredChildDistributions).map { @@ -227,9 +227,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val leftKeysBuffer = ArrayBuffer[Expression]() val rightKeysBuffer = ArrayBuffer[Expression]() + val pickedIndexes = mutable.Set[Int]() + val keysAndIndexes = currentOrderOfKeys.zipWithIndex expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + val index = keysAndIndexes.find { case (e, idx) => + // As we may have the same key used many times, we need to filter out its occurrence we + // have already used. + e.semanticEquals(expression) && !pickedIndexes.contains(idx) + }.map(_._2).get + pickedIndexes += index leftKeysBuffer.append(leftKeys(index)) rightKeysBuffer.append(rightKeys(index)) }) @@ -270,14 +277,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { * partitioning of the join nodes' children. 
*/ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - plan.transformUp { - case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, - right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - + plan match { case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) @@ -288,6 +288,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 09f79a2de0ba0..1a5b7599bb7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan } override def outputPartitioning: Partitioning = child.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 78f11ca8d8c78..f5d93ee5fa914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -83,7 +83,6 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MB) */ class ExchangeCoordinator( - numExchanges: Int, advisoryTargetPostShuffleInputSize: Long, minNumPostShufflePartitions: Option[Int] = None) extends Logging { @@ -91,8 +90,14 @@ class ExchangeCoordinator( // The registered Exchange operators. private[this] val exchanges = ArrayBuffer[ShuffleExchangeExec]() + // `lazy val` is used here so that we could notice the wrong use of this class, e.g., all the + // exchanges should be registered before `postShuffleRDD` called first time. If a new exchange is + // registered after the `postShuffleRDD` call, `assert(exchanges.length == numExchanges)` fails + // in `doEstimationIfNecessary`. 
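The lazy vals introduced here snapshot the number of registered exchanges only on first use, which is what lets the assert(exchanges.length == numExchanges) mentioned in the comment catch exchanges that were registered too late. A small, self-contained illustration of that evaluate-once-on-first-access behaviour:

import scala.collection.mutable.ArrayBuffer

class Registry {
  private val items = ArrayBuffer[String]()
  // Captured the first time it is read; registrations after that are not reflected.
  lazy val sizeAtFirstUse: Int = items.size
  def register(name: String): Unit = items += name
}

object LazySnapshotDemo extends App {
  val r = new Registry
  r.register("a")
  r.register("b")
  println(r.sizeAtFirstUse) // 2: evaluated here, on first access
  r.register("c")           // registered too late to be counted
  println(r.sizeAtFirstUse) // still 2, so a size-based assertion can now flag the late registration
}
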
+ private[this] lazy val numExchanges = exchanges.size + // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = + private[this] lazy val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. @@ -117,10 +122,6 @@ class ExchangeCoordinator( */ def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { - // If we have mapOutputStatistics.length < numExchange, it is because we do not submit - // a stage when the number of partitions of this dependency is 0. - assert(mapOutputStatistics.length <= numExchanges) - // If minNumPostShufflePartitions is defined, it is possible that we need to use a // value less than advisoryTargetPostShuffleInputSize as the target input size of // a post shuffle task. @@ -228,20 +229,24 @@ class ExchangeCoordinator( j += 1 } + // If we have mapOutputStatistics.length < numExchange, it is because we do not submit + // a stage when the number of partitions of this dependency is 0. + assert(mapOutputStatistics.length <= numExchanges) + // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the // number of post-shuffle partitions. val partitionStartIndices = if (mapOutputStatistics.length == 0) { - None + Array.empty[Int] } else { - Some(estimatePartitionStartIndices(mapOutputStatistics)) + estimatePartitionStartIndices(mapOutputStatistics) } var k = 0 while (k < numExchanges) { val exchange = exchanges(k) val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) + exchange.preparePostShuffleRDD(shuffleDependencies(k), Some(partitionStartIndices)) newPostShuffleRDDs.put(exchange, rdd) k += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index b89203719541b..50f10c31427d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -231,6 +231,11 @@ object ShuffleExchangeExec { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } + case l: LocalPartitioning => + new Partitioner { + override def numPartitions: Int = l.numPartitions + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. 
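The LocalPartitioning case added above keeps each row in the partition that produced it: the companion hunk that follows computes the shuffle key as the producing task's partition id, and this partitioner maps that id straight back to itself. A tiny self-contained sketch of that identity mapping, using a local trait in place of org.apache.spark.Partitioner:

// Minimal stand-in for org.apache.spark.Partitioner, to keep the sketch dependency-free.
trait SimplePartitioner {
  def numPartitions: Int
  def getPartition(key: Any): Int
}

// "Local" partitioning: the shuffle key is the producing partition's id and it maps to itself,
// so partition i of the input becomes partition i of the output.
class IdentityPartitioner(override val numPartitions: Int) extends SimplePartitioner {
  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

object IdentityPartitionerDemo extends App {
  val p = new IdentityPartitioner(4)
  // Every row produced by partition 2 is tagged with key 2 and therefore stays in partition 2.
  println(p.getPartition(2)) // 2
}
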
} @@ -247,6 +252,9 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity + case _: LocalPartitioning => + val partitionId = TaskContext.get().partitionId() + _ => partitionId case _ => sys.error(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 6fa716d9fadee..a6f3ea47c8492 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -183,7 +184,7 @@ case class BroadcastHashJoinExec( val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val javaType = CodeGenerator.javaType(a.dataType) - val code = s""" + val code = code""" |boolean $isNull = true; |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { @@ -319,7 +320,7 @@ case class BroadcastHashJoinExec( |if (!$conditionPassed) { | $matched = null; | // reset the variables those are already evaluated. - | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")} + | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} = true;").mkString("\n")} |} |$numOutput.add(1); |${consume(ctx, resultVars)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0396168d3f311..dab873bf9b9a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -214,7 +214,7 @@ trait HashJoin { } // At the end of the task, we update the avg hash probe. 
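The hunks that follow change context.addTaskCompletionListener(_ => ...) to addTaskCompletionListener[Unit](_ => ...). The explicit type argument appears to be there so overload resolution stays unambiguous once a lambda can also SAM-convert to the TaskCompletionListener interface (as under Scala 2.12); pinning the type parameter steers the call to the function-taking overload. A hypothetical illustration of the same pattern:

// Hypothetical pair of overloads mirroring the shape of the TaskContext API:
// one takes a single-abstract-method listener, the other a generic function.
trait CompletionListener {
  def onTaskCompletion(ctx: String): Unit
}

object Registrar {
  def addListener(l: CompletionListener): Unit = l.onTaskCompletion("task-ctx")
  def addListener[U](f: String => U): Unit = f("task-ctx")
}

object ListenerDemo extends App {
  // Registrar.addListener(_ => println("done"))
  // ^ once lambdas can SAM-convert to CompletionListener, this call can match both overloads;
  //   the explicit type argument below unambiguously selects the function-taking overload.
  Registrar.addListener[Unit](_ => println("done"))
}
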
- TaskContext.get().addTaskCompletionListener(_ => + TaskContext.get().addTaskCompletionListener[Unit](_ => avgHashProbe.set(hashed.getAverageProbesPerLookup)) val resultProj = createResultProjection diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 1465346eb802d..86eb47a70f1ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap def append(key: Long, row: UnsafeRow): Unit = { val sizeInBytes = row.getSizeInBytes if (sizeInBytes >= (1 << SIZE_BITS)) { - sys.error("Does not support row that is larger than 256M") + throw new UnsupportedOperationException("Does not support row that is larger than 256M") } if (key < minKey) { @@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = key } - // There is 8 bytes for the pointer to next value - if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) { - val used = page.length - if (used >= (1 << 30)) { - sys.error("Can not build a HashedRelation that is larger than 8G") - } - ensureAcquireMemory(used * 8L * 2) - val newPage = new Array[Long](used * 2) - Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, - cursor - Platform.LONG_ARRAY_OFFSET) - page = newPage - freeMemory(used * 8L) - } + grow(row.getSizeInBytes) // copy the bytes of UnsafeRow val offset = cursor @@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap growArray() } else if (numKeys > array.length / 2 * 0.75) { // The fill ratio should be less than 0.75 - sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys") + throw new UnsupportedOperationException( + "Cannot build HashedRelation with more than 1/3 billions unique keys") } } } else { @@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } + private def grow(inputRowSize: Int): Unit = { + // There is 8 bytes for the pointer to next value + val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8 + if (neededNumWords > page.length) { + if (neededNumWords > (1 << 30)) { + throw new UnsupportedOperationException( + "Can not build a HashedRelation that is larger than 8G") + } + val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30)) + ensureAcquireMemory(newNumWords * 8L) + val newPage = new Array[Long](newNumWords.toInt) + Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, + cursor - Platform.LONG_ARRAY_OFFSET) + val used = page.length + page = newPage + freeMemory(used * 8L) + } + } + private def growArray(): Unit = { var old_array = array val n = array.length @@ -764,6 +772,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap array = readLongArray(readBuffer, length) val pageLength = readLong().toInt page = readLongArray(readBuffer, pageLength) + // Restore cursor variable to make this map able to be serialized again on executors. 
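The LongToUnsafeRowMap changes above move the inline page-resizing logic into a grow() helper that computes the words actually needed (current contents, an 8-byte next-value pointer, and the incoming row, rounded up to whole 8-byte words) and then grows geometrically, capped at 2^30 words (8 GB). A simplified standalone rendering of that sizing arithmetic, where usedBytes stands in for cursor minus the array base offset:

object PageGrowthDemo extends App {
  val MaxWords: Long = 1L << 30 // 2^30 eight-byte words == 8 GB, the hard cap in the real map

  // Words needed to hold `usedBytes` already written, an 8-byte next-value pointer,
  // and `rowSizeBytes` of incoming data, rounded up to whole 8-byte words.
  def neededWords(usedBytes: Long, rowSizeBytes: Int): Long =
    (usedBytes + 8 + rowSizeBytes + 7) / 8

  // Grow geometrically (at least doubling) but never past the cap and never less than needed.
  def nextPageSize(currentWords: Long, needed: Long): Long = {
    require(needed <= MaxWords, "row does not fit under the 8G page cap")
    math.max(needed, math.min(currentWords * 2, MaxWords))
  }

  val needed = neededWords(usedBytes = 1000, rowSizeBytes = 256)
  println(needed)                                          // 158
  println(nextPageSize(currentWords = 128, needed = needed)) // 256: at least doubling
  println(nextPageSize(currentWords = 64, needed = needed))  // 158: the needed size wins when doubling is not enough
}
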
+ cursor = pageLength * 8 + Platform.LONG_ARRAY_OFFSET } override def readExternal(in: ObjectInput): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 897a4dae39f32..2b59ed6e4d16b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -57,7 +57,7 @@ case class ShuffledHashJoinExec( buildTime += (System.nanoTime() - start) / 1000000 buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. - context.addTaskCompletionListener(_ => relation.close()) + context.addTaskCompletionListener[Unit](_ => relation.close()) relation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d8261f0f33b61..f4b9d132122e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -521,7 +522,7 @@ case class SortMergeJoinExec( if (a.nullable) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin @@ -533,7 +534,7 @@ case class SortMergeJoinExec( (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), leftVarsDecl) } else { - val code = s"$value = $valueCode;" + val code = code"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 66bcda8913738..fb46970e38f3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -47,13 +47,16 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } /** - * Helper trait which defines methods that are shared by both - * [[LocalLimitExec]] and [[GlobalLimitExec]]. + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. 
*/ -trait BaseLimitExec extends UnaryExecNode with CodegenSupport { - val limit: Int +case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport { + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -93,25 +96,93 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Take the `limit` elements of the child output. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning -} -/** - * Take the first `limit` elements of the child's single output partition. - */ -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = { + val childRDD = child.execute() + val partitioner = LocalPartitioning(childRDD) + val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( + childRDD, child.output, partitioner, serializer) + val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { + // submitMapStage does not accept RDD with 0 partition. + // So, we will not submit this dependency. + val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) + submittedStageFuture.get().recordsByPartitionId.toSeq + } else { + Nil + } - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + // This is an optimization to evenly distribute limited rows across all partitions. + // When enabled, Spark goes to take rows at each partition repeatedly until reaching + // limit number. When disabled, Spark takes all rows at first partition, then rows + // at second partition ..., until reaching limit number. + val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit + + val shuffled = new ShuffledRowRDD(shuffleDependency) + + val sumOfOutput = numberOfOutput.sum + if (sumOfOutput <= limit) { + shuffled + } else if (!flatGlobalLimit) { + var numRowTaken = 0 + val takeAmounts = numberOfOutput.map { num => + if (numRowTaken + num < limit) { + numRowTaken += num.toInt + num.toInt + } else { + val toTake = limit - numRowTaken + numRowTaken += toTake + toTake + } + } + val broadMap = sparkContext.broadcast(takeAmounts) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } else { + // We try to evenly require the asked limit number of rows across all child rdd's partitions. 
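The loop that follows spreads the remaining limit over the non-empty partitions, taking max(1, remaining / nonEmptyParts) rows from each per round until the limit is exhausted. A standalone sketch of that allocation, assuming the per-partition row counts are already known and sum to at least the limit (the real code then broadcasts the resulting array and applies iter.take per partition):

object FlatLimitDemo extends App {
  // Decide how many rows to take from each partition so the total equals `limit`,
  // spreading the work roughly evenly instead of draining partitions left to right.
  def takePerPartition(counts: Array[Long], limit: Long): Array[Long] = {
    val take = Array.fill(counts.length)(0L)
    val remaining = counts.clone()
    var rowsNeeded = limit
    while (rowsNeeded > 0) {
      val nonEmpty = remaining.count(_ > 0)
      val perPart = math.max(1L, rowsNeeded / nonEmpty)
      var i = 0
      while (i < remaining.length && rowsNeeded > 0) {
        if (remaining(i) > 0) {
          val t = math.min(perPart, math.min(remaining(i), rowsNeeded))
          take(i) += t
          remaining(i) -= t
          rowsNeeded -= t
        }
        i += 1
      }
    }
    take
  }

  // Partitions holding 10, 3 and 7 rows, global limit of 12.
  println(takePerPartition(Array(10L, 3L, 7L), limit = 12).mkString(", ")) // 5, 3, 4: sums to 12 and touches every partition
}
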
+ var rowsNeedToTake: Long = limit + val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L) + val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*) + + while (rowsNeedToTake > 0) { + val nonEmptyParts = remainingRowsByPartition.count(_ > 0) + // If the rows needed to take are less the number of non-empty partitions, take one row from + // each non-empty partitions until we reach `limit` rows. + // Otherwise, evenly divide the needed rows to each non-empty partitions. + val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts) + remainingRowsByPartition.zipWithIndex.foreach { case (num, index) => + // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during + // the traversal, so we need to add this check. + if (rowsNeedToTake > 0 && num > 0) { + if (num >= takePerPart) { + rowsNeedToTake -= takePerPart + takeAmountByPartition(index) += takePerPart + remainingRowsByPartition(index) -= takePerPart + } else { + rowsNeedToTake -= num + takeAmountByPartition(index) += num + remainingRowsByPartition(index) -= num + } + } + } + } + val broadMap = sparkContext.broadcast(takeAmountByPartition) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 77b907870d678..cbf707f4a9cfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,7 +104,7 @@ object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): SQLMetric = { - // The final result of this metric in physical operator UI may looks like: + // The final result of this metric in physical operator UI may look like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) val acc = new SQLMetric(SIZE_METRIC, -1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 8e01e8e56a5bd..2ab7240556aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -78,10 +79,8 @@ case class AggregateInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip @@ -124,7 
+123,7 @@ case class AggregateInPandasExec( // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => queue.close() } @@ -135,10 +134,12 @@ case class AggregateInPandasExec( } val columnarBatchIter = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(projectedRowIter, context.partitionId(), context) + pyFuncs, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c4de214679ae4..2b87796dc6833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -23,7 +23,9 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -56,19 +58,23 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) } /** - * A physical plan that evaluates a [[PythonUDF]], + * A logical plan that evaluates a [[PythonUDF]]. + */ +case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) + extends UnaryNode + +/** + * A physical plan that evaluates a [[PythonUDF]]. 
*/ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends EvalPythonExec(udfs, output, child) { private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) protected override def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -80,10 +86,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( - funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(batchIter, context.partitionId(), context) + funcs, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + argOffsets, + schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5fcdcddca7d51..18992d7a9f974 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -39,15 +39,13 @@ import org.apache.spark.util.Utils */ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], schema: StructType, timeZoneId: String, - respectTimeZone: Boolean) + conf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, bufferSize, reuseWorker, evalType, argOffsets) { + funcs, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, @@ -58,31 +56,28 @@ class ArrowPythonRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - if (respectTimeZone) { - PythonRDD.writeUTF(timeZoneId, dataOut) - } else { - dataOut.writeInt(SpecialLengths.NULL) + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val arrowWriter = ArrowWriter.create(root) - - context.addTaskCompletionListener { _ => - root.close() - allocator.close() - } - - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() Utils.tryWithSafeFinally { + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + while (inputIterator.hasNext) { val nextBatch = inputIterator.next() @@ -94,8 +89,21 @@ 
class ArrowPythonRunner( writer.writeBatch() arrowWriter.reset() } - } { + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. root.close() allocator.close() } @@ -121,7 +129,7 @@ class ArrowPythonRunner( private var schema: StructType = _ private var vectors: Array[ColumnVector] = _ - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => if (reader != null) { reader.close(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index f4d83e8dc7c2b..b08b7e60e130b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -25,9 +25,16 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{StructField, StructType} +/** + * A logical plan that evaluates a [[PythonUDF]] + */ +case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) + extends UnaryNode + /** * A physical plan that evaluates a [[PythonUDF]] */ @@ -36,8 +43,6 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi protected override def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -68,8 +73,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. 
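Earlier in this ArrowPythonRunner hunk, writeCommand now frames the worker configuration as a count followed by key/value string pairs, written before the serialized UDFs. A minimal sketch of that framing with plain java.io streams (the real code uses PythonRDD.writeUTF, which length-prefixes raw UTF-8 bytes, rather than DataOutputStream.writeUTF's modified UTF-8):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object ConfFramingDemo extends App {
  // Write a small string-to-string config map as: count, then (key, value) pairs.
  def writeConf(out: DataOutputStream, conf: Map[String, String]): Unit = {
    out.writeInt(conf.size)
    for ((k, v) <- conf) {
      out.writeUTF(k)
      out.writeUTF(v)
    }
  }

  def readConf(in: DataInputStream): Map[String, String] =
    (0 until in.readInt()).map(_ => in.readUTF() -> in.readUTF()).toMap

  val bytes = new ByteArrayOutputStream()
  writeConf(new DataOutputStream(bytes), Map("spark.sql.session.timeZone" -> "UTC"))
  println(readConf(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))))
}
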
- val outputIterator = new PythonUDFRunner( - funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) + val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 860dc78c1dd1b..942a6db57416e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -78,8 +78,6 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil protected def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -87,8 +85,6 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) inputRDD.mapPartitions { iter => val context = TaskContext.get() @@ -97,7 +93,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener { ctx => + context.addTaskCompletionListener[Unit] { ctx => queue.close() } @@ -129,7 +125,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil } val outputRowIterator = evaluate( - pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context) + pyFuncs, argOffsets, projectedRowIter, schema, context) val joined = new JoinedRow val resultProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 9d56f48249982..90b5325919e96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -21,11 +21,11 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** @@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || - PythonUDF.isGroupAggPandasUDF(e) || + PythonUDF.isGroupedAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } @@ -92,38 +92,54 @@ object 
ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { +object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { - private def hasPythonUDF(e: Expression): Boolean = { + private type EvalType = Int + private type EvalTypeChecker = EvalType => Boolean + + private def hasScalarPythonUDF(e: Expression): Boolean = { e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { e.children match { // single PythonUDF child could be chained and evaluated in Python - case Seq(u: PythonUDF) => canEvaluateInPython(u) + case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u) // Python UDF can't be evaluated directly in JVM - case children => !children.exists(hasPythonUDF) + case children => !children.exists(hasScalarPythonUDF) } } - private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) - case e => e.children.flatMap(collectEvaluatableUDF) + private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = { + // Eval type checker is set once when we find the first evaluable UDF and its value + // shouldn't change later. + // Used to check if subsequent UDFs are of the same type as the first UDF. (since we can only + // extract UDFs of the same eval type) + var evalTypeChecker: Option[EvalTypeChecker] = None + + def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && evalTypeChecker.isEmpty => + evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType) + Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && evalTypeChecker.get(udf.evalType) => + Seq(udf) + case e => e.children.flatMap(collectEvaluableUDFs) + } + + expressions.flatMap(collectEvaluableUDFs) } - def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker - // Therefore we don't need to extract the UDFs - case plan: FlatMapGroupsInPandasExec => plan - case plan: SparkPlan => extract(plan) + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case plan: LogicalPlan => extract(plan) } /** * Extract all the PythonUDFs from the current operator and evaluate them before the operator. 
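collectEvaluableUDFsFromExpressions above latches onto the eval type of the first evaluable UDF it finds and then only collects UDFs of that same type, since one extraction pass can introduce only one ArrowEvalPython or BatchEvalPython node. A simplified sketch of that first-match-sets-the-filter traversal over a toy expression tree (the scalar/evaluability checks are omitted):

// Toy expression tree: leaves plus UDF nodes tagged with an eval type (e.g. batched vs pandas).
sealed trait Expr { def children: Seq[Expr] }
case class Leaf(name: String) extends Expr { def children: Seq[Expr] = Nil }
case class Udf(name: String, evalType: Int, children: Seq[Expr]) extends Expr

object SameEvalTypeDemo extends App {
  def collectSameType(exprs: Seq[Expr]): Seq[Udf] = {
    var wantedType: Option[Int] = None // fixed by the first UDF we find, like evalTypeChecker

    def collect(e: Expr): Seq[Udf] = e match {
      case u: Udf if wantedType.forall(_ == u.evalType) =>
        wantedType = Some(u.evalType)
        Seq(u)
      case other => other.children.flatMap(collect)
    }

    exprs.flatMap(collect)
  }

  val exprs = Seq(Udf("f", evalType = 100, Nil), Leaf("x"), Udf("g", evalType = 200, Nil),
    Udf("h", evalType = 100, Nil))
  // Only UDFs sharing the first UDF's eval type (100) are collected; g is left for a later pass.
  println(collectSameType(exprs).map(_.name)) // List(f, h)
}
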
*/ - private def extract(plan: SparkPlan): SparkPlan = { - val udfs = plan.expressions.flatMap(collectEvaluatableUDF) + private def extract(plan: LogicalPlan): LogicalPlan = { + val udfs = collectEvaluableUDFsFromExpressions(plan.expressions) // ignore the PythonUDF that come from second/third aggregate, which is not used .filter(udf => udf.references.subsetOf(plan.inputSet)) if (udfs.isEmpty) { @@ -134,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val prunedChildren = plan.children.map { child => val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq if (allNeededOutput.length != child.output.length) { - ProjectExec(allNeededOutput, child) + Project(allNeededOutput, child) } else { child } @@ -163,11 +179,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => - ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) + ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => - BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) + BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child) case _ => - throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs") + throw new AnalysisException( + "Expected either Scalar Pandas UDFs or Batched UDFs but got both") } attributeMap ++= validUdfs.zip(resultAttrs) @@ -191,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - ProjectExec(plan.output, newPlan) + Project(plan.output, newPlan) } else { newPlan } @@ -200,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { // Split the original FilterExec to two FilterExecs. Only push down the first few predicates // that are all deterministic. 
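trySplitFilter, shown next, pushes only the deterministic, UDF-free conjuncts beneath the extracted Python evaluation and keeps everything else in a second Filter above it. A simplified sketch of that split over toy predicates:

object FilterSplitDemo extends App {
  // Toy predicate: a name plus the two properties the split cares about.
  case class Pred(name: String, deterministic: Boolean, usesPythonUdf: Boolean)

  // Split a conjunction into the part that can stay below the Python evaluation
  // (deterministic and UDF-free) and the part that must be re-checked above it.
  def split(conjuncts: Seq[Pred]): (Seq[Pred], Seq[Pred]) = {
    val (candidates, nonDeterministic) = conjuncts.partition(_.deterministic)
    val (pushDown, rest) = candidates.partition(!_.usesPythonUdf)
    (pushDown, rest ++ nonDeterministic)
  }

  val conjuncts = Seq(
    Pred("a > 1", deterministic = true, usesPythonUdf = false),
    Pred("udf(b)", deterministic = true, usesPythonUdf = true),
    Pred("rand() < 0.5", deterministic = false, usesPythonUdf = false))

  val (below, above) = split(conjuncts)
  println(below.map(_.name)) // List(a > 1): pushed beneath the extracted Python evaluation
  println(above.map(_.name)) // List(udf(b), rand() < 0.5): kept in the Filter above it
}
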
- private def trySplitFilter(plan: SparkPlan): SparkPlan = { + private def trySplitFilter(plan: LogicalPlan): LogicalPlan = { plan match { - case filter: FilterExec => + case filter: Filter => val (candidates, nonDeterministic) = splitConjunctivePredicates(filter.condition).partition(_.deterministic) - val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) + val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_)) if (pushDown.nonEmpty) { - val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) - FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) + val newChild = Filter(pushDown.reduceLeft(And), filter.child) + Filter((rest ++ nonDeterministic).reduceLeft(And), newChild) } else { filter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 513e174c7733e..e9cff1a5a2007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -73,11 +74,9 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Deduplicate the grouping attributes. // If a grouping attribute also appears in data attributes, then we don't need to send the @@ -139,10 +138,12 @@ case class FlatMapGroupsInPandasExec( val context = TaskContext.get() val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(grouped, context.partitionId(), context) + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + argOffsets, + dedupSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala new file mode 100644 index 0000000000000..a4e9b3305052f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{NextIterator, Utils} + +class PythonForeachWriter(func: PythonFunction, schema: StructType) + extends ForeachWriter[UnsafeRow] { + + private lazy val context = TaskContext.get() + private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer( + context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length) + private lazy val inputRowIterator = buffer.iterator + + private lazy val inputByteIterator = { + EvaluatePython.registerPicklers() + val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) } + new SerDeUtil.AutoBatchedPickler(objIterator) + } + + private lazy val pythonRunner = { + PythonRunner(func) + } + + private lazy val outputIterator = + pythonRunner.compute(inputByteIterator, context.partitionId(), context) + + override def open(partitionId: Long, version: Long): Boolean = { + outputIterator // initialize everything + TaskContext.get.addTaskCompletionListener[Unit] { _ => buffer.close() } + true + } + + override def process(value: UnsafeRow): Unit = { + buffer.add(value) + } + + override def close(errorOrNull: Throwable): Unit = { + buffer.allRowsAdded() + if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + } +} + +object PythonForeachWriter { + + /** + * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. + * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader + * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python + * worker stdin). Additions to the buffer are non-blocking, and reads through the buffer's iterator + * are blocking, that is, it blocks until new data is available or all data has been added. + * + * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue + * across memory and local disk. However, HybridRowQueue is designed to be used only with + * EvalPythonExec where the reader is always behind the writer, that is, the reader does not + * try to read n+1 rows if the writer has only written n rows at any point of time. This + * assumption is not true for PythonForeachWriter where rows may be added at a different rate as + * they are consumed by the python worker.
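As an aside on the blocking-read, non-blocking-write contract described in this note (which continues below), the ReentrantLock and Condition shape borrowed from ArrayBlockingQueue can be sketched in isolation roughly as follows, using a plain in-memory queue instead of HybridRowQueue:

import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import scala.collection.mutable

// Minimal single-producer / single-consumer buffer: add() never blocks,
// take() blocks until an element arrives or the producer declares completion.
class BlockingBufferSketch[T] {
  private val lock = new ReentrantLock()
  private val notEmpty = lock.newCondition()
  private val queue = mutable.Queue.empty[T]
  private var allAdded = false

  def add(elem: T): Unit = withLock {
    queue.enqueue(elem)
    notEmpty.signal()
  }

  def allElementsAdded(): Unit = withLock {
    allAdded = true
    notEmpty.signal()
  }

  // Returns None once the producer is done and the buffer has drained.
  def take(): Option[T] = withLock {
    while (queue.isEmpty && !allAdded) {
      notEmpty.await(100, TimeUnit.MILLISECONDS)
    }
    if (queue.nonEmpty) Some(queue.dequeue()) else None
  }

  private def withLock[A](body: => A): A = {
    lock.lockInterruptibly()
    try body finally lock.unlock()
  }
}

The error-propagation and row-counting details of the real UnsafeRowBuffer are deliberately left out here.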
Hence, to maintain the invariant of the reader being + * behind the writer while using HybridRowQueue, the buffer does the following + * - Keeps a count of the rows in the HybridRowQueue + * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not + * try to read more rows than what has been written. + * + * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed + * from that of ArrayBlockingQueue. + */ + class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int) + extends Logging { + private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields) + private val lock = new ReentrantLock() + private val unblockRemove = lock.newCondition() + + // All of these are guarded by `lock` + private var count = 0L + private var allAdded = false + private var exception: Throwable = null + + val iterator = new NextIterator[UnsafeRow] { + override protected def getNext(): UnsafeRow = { + val row = remove() + if (row == null) finished = true + row + } + override protected def close(): Unit = { } + } + + def add(row: UnsafeRow): Unit = withLock { + assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count += 1 + unblockRemove.signal() + logTrace(s"Added $row, $count left") + } + + private def remove(): UnsafeRow = withLock { + while (count == 0 && !allAdded && exception == null) { + unblockRemove.await(100, TimeUnit.MILLISECONDS) + } + + // If there was any error in the adding thread, then rethrow it in the removing thread + if (exception != null) throw exception + + if (count > 0) { + val row = queue.remove() + assert(row != null, "HybridRowQueue.remove() returned null " + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count -= 1 + logTrace(s"Removed $row, $count left") + row + } else { + null + } + } + + def allRowsAdded(): Unit = withLock { + allAdded = true + unblockRemove.signal() + } + + def close(): Unit = { queue.close() } + + private def withLock[T](f: => T): T = { + lock.lockInterruptibly() + try { f } catch { + case e: Throwable => + if (exception == null) exception = e + throw e + } finally { lock.unlock() } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index e28def1c4b423..cc61faa7e7051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -29,12 +29,10 @@ import org.apache.spark.api.python._ */ class PythonUDFRunner( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, bufferSize, reuseWorker, evalType, argOffsets) { + funcs, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala new file mode 100644 index 0000000000000..27bed1137e5b3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +case class WindowInPandasExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. 
+ */ + private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + // Extract window expressions and window functions + val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) + + val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val context = TaskContext.get() + + val grouped = if (partitionSpec.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, partitionSpec, child.output) + } + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. 
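The comment above is the heart of the pattern: full input rows are buffered locally while only the UDF's input columns travel to the Python worker, and each returned value is then re-joined, in order, with the buffered row it came from. A minimal standalone sketch of that buffer-and-rejoin idea (plain Scala collections, no Arrow and no HybridRowQueue):

import scala.collection.mutable

object BufferAndRejoinSketch {
  case class Row(key: String, payload: Int) // stand-in for an input row

  def main(args: Array[String]): Unit = {
    val input = Iterator(Row("a", 1), Row("a", 2), Row("b", 3))
    val queue = mutable.Queue.empty[Row] // plays the role of the row queue

    // Project each row down to what the "worker" needs, remembering the full row.
    val projected = input.map { row => queue.enqueue(row); row.payload }

    // Pretend this is the Python worker producing one result per input row.
    val results = projected.map(_ * 10)

    // Re-join each result with the buffered original row, in order.
    results.map(result => (queue.dequeue(), result)).foreach {
      case (row, result) => println(s"$row -> windowResult=$result")
    }
  }
}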
+ val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener[Unit] { _ => + queue.close() + } + + val inputProj = UnsafeProjection.create(allInputs, child.output) + val pythonInput = grouped.map { case (_, rows) => + rows.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + inputProj(row) + } + } + + val windowFunctionResult = new ArrowPythonRunner( + pyFuncs, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + argOffsets, + windowInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(pythonInput, context.partitionId(), context) + + val joined = new JoinedRow + val resultProj = createResultProjection(expressions) + + windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, windowOutput) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 685d5841ab551..bea652cc33076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -157,7 +157,7 @@ object StatFunctions extends Logging { cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + - s"for columns with dataType ${data.get.dataType} not supported.") + s"for columns with dataType ${data.get.dataType.catalogString} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala index 5b114242558dc..0063318db332d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala @@ -22,6 +22,9 @@ import java.nio.charset.StandardCharsets._ import scala.io.{Source => IOSource} +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + import org.apache.spark.sql.SparkSession /** @@ -43,36 +46,28 @@ import org.apache.spark.sql.SparkSession * line 2: metadata (optional json string) */ class CommitLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[String](sparkSession, path) { + extends HDFSMetadataLog[CommitMetadata](sparkSession, path) { import CommitLog._ - def add(batchId: Long): Unit = { - super.add(batchId, EMPTY_JSON) - } - - override def add(batchId: Long, metadata: String): Boolean = { - throw new UnsupportedOperationException( - "CommitLog does not take any metadata, use 'add(batchId)' instead") - } - - override protected def deserialize(in: InputStream): String = { + override protected def deserialize(in: InputStream): CommitMetadata = { // called inside a try-finally where the underlying stream is closed in the caller val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file in the offset commit log") } parseVersion(lines.next.trim, VERSION) - EMPTY_JSON + 
val metadataJson = if (lines.hasNext) lines.next else EMPTY_JSON + CommitMetadata(metadataJson) } - override protected def serialize(metadata: String, out: OutputStream): Unit = { + override protected def serialize(metadata: CommitMetadata, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller out.write(s"v${VERSION}".getBytes(UTF_8)) out.write('\n') // write metadata - out.write(EMPTY_JSON.getBytes(UTF_8)) + out.write(metadata.json.getBytes(UTF_8)) } } @@ -81,3 +76,13 @@ object CommitLog { private val EMPTY_JSON = "{}" } + +case class CommitMetadata(nextBatchWatermarkMs: Long = 0) { + def json: String = Serialization.write(this)(CommitMetadata.format) +} + +object CommitMetadata { + implicit val format = Serialization.formats(NoTypeHints) + + def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala new file mode 100644 index 0000000000000..c9c2ebc875f28 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.SparkEnv +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset + +case class ContinuousRecordPartitionOffset(partitionId: Int, offset: Int) extends PartitionOffset +case class GetRecord(offset: ContinuousRecordPartitionOffset) + +/** + * An RPC endpoint for continuous readers to poll for + * records from the driver. + * + * @param buckets the data buckets. Each bucket contains a sequence of items to be + * returned for a partition. The number of buckets should be equal to + * the number of partitions. + * @param lock a lock object for locking the buckets for read + */ +class ContinuousRecordEndpoint(buckets: Seq[Seq[Any]], lock: Object) + extends ThreadSafeRpcEndpoint { + + private var startOffsets: Seq[Int] = List.fill(buckets.size)(0) + + /** + * Sets the start offset. + * + * @param offsets the base offset per partition to be used + * while retrieving the data in {#receiveAndReply}. + */ + def setStartOffsets(offsets: Seq[Int]): Unit = { + lock.synchronized { + startOffsets = offsets + } + } + + override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv + + /** + * Process messages from `RpcEndpointRef.ask`. If receiving an unmatched message, + * `SparkException` will be thrown and sent to `onError`.
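Stepping back to the CommitMetadata case class added above: the json4s usage it relies on is a plain write/read round-trip with NoTypeHints. A minimal standalone sketch of that round-trip (assumes json4s-jackson on the classpath; WatermarkMetadata is an illustrative name, not a Spark class):

import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

// Illustrative stand-in for a commit-log metadata case class.
case class WatermarkMetadata(nextBatchWatermarkMs: Long = 0)

object CommitMetadataJsonSketch {
  implicit val formats = Serialization.formats(NoTypeHints)

  def main(args: Array[String]): Unit = {
    val json = Serialization.write(WatermarkMetadata(nextBatchWatermarkMs = 5000L))
    println(json) // {"nextBatchWatermarkMs":5000}
    val roundTripped = Serialization.read[WatermarkMetadata](json)
    println(roundTripped.nextBatchWatermarkMs) // 5000
  }
}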
+ */ + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetRecord(ContinuousRecordPartitionOffset(partitionId, offset)) => + lock.synchronized { + val bufOffset = offset - startOffsets(partitionId) + val buf = buckets(partitionId) + val record = if (buf.size <= bufOffset) None else Some(buf(bufOffset)) + + context.reply(record.map(InternalRow(_))) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 8c016abc5b643..103fa7ce9066d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -50,7 +50,7 @@ class FileStreamSource( @transient private val fs = new Path(path).getFileSystem(hadoopConf) private val qualifiedBasePath: Path = { - fs.makeQualified(new Path(path)) // can contains glob patterns + fs.makeQualified(new Path(path)) // can contain glob patterns } private val optionsWithPartitionBasePath = sourceOptions.optionMapWithoutPath ++ { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 80769d728b8f1..bfe7d00f56048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -52,6 +50,7 @@ case class FlatMapGroupsWithStateExec( outputObjAttr: Attribute, stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], + stateFormatVersion: Int, outputMode: OutputMode, timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], @@ -60,32 +59,15 @@ case class FlatMapGroupsWithStateExec( ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { import GroupStateImpl._ + import FlatMapGroupsWithStateExecHelper._ private val isTimeoutEnabled = timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. 
Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - private val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } + private[sql] val stateManager = + createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -97,6 +79,18 @@ case class FlatMapGroupsWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + timeoutConf match { + case ProcessingTimeTimeout => + true // Always run batches to process timeouts + case EventTimeTimeout => + // Process another non-data batch only if the watermark has changed in this executed plan + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + case _ => + false + } + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -113,11 +107,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -126,13 +120,12 @@ case class FlatMapGroupsWithStateExec( case _ => iter } - // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. 
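To see why the ordering stated in the comment above matters (all new data is processed before any timeout handling, so a key that just received data is not also timed out in the same batch), here is a simplified standalone model of the processNewData ++ processTimedOutState composition with toy in-memory state:

object TimeoutOrderingSketch {
  case class KeyState(value: Int, timeoutTimestamp: Long)

  def main(args: Array[String]): Unit = {
    val NoTimestamp = -1L
    val batchTimestampMs = 100L
    var state = Map(
      "a" -> KeyState(1, timeoutTimestamp = 50L),
      "b" -> KeyState(2, timeoutTimestamp = 50L))
    val newData = Map("a" -> Seq(10)) // only key "a" sees data this batch

    // Phase 1: process keys with new data; this also pushes their timeout forward.
    def processNewData(): Iterator[String] = newData.iterator.map { case (key, values) =>
      val old = state(key)
      state += key -> KeyState(old.value + values.sum, timeoutTimestamp = batchTimestampMs + 100L)
      s"updated $key"
    }

    // Phase 2: only afterwards, time out keys whose timestamp is still below the threshold.
    def processTimedOutState(): Iterator[String] = state.iterator.collect {
      case (key, s) if s.timeoutTimestamp != NoTimestamp && s.timeoutTimestamp < batchTimestampMs =>
        s"timed out $key"
    }

    // Iterator.++ evaluates its right side lazily, so phase 2 sees phase 1's updates.
    (processNewData() ++ processTimedOutState()).foreach(println)
    // prints "updated a" then "timed out b"; "a" survives because phase 1 ran first
  }
}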
val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -147,7 +140,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -156,14 +149,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -172,20 +157,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -194,12 +178,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutPairs.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -209,22 +192,19 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. 
* - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -232,50 +212,24 @@ case class FlatMapGroupsWithStateExec( watermarkPresent) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. 
- if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -284,28 +238,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 1a83c884d55bd..fad287e28877d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -22,7 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} -import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp +import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, ExpressionWithRandomSeed} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import 
org.apache.spark.sql.catalyst.rules.Rule @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.util.Utils /** * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] @@ -59,7 +60,8 @@ class IncrementalExecution( StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: StreamingRelationStrategy :: - StreamingDeduplicationStrategy :: Nil + StreamingDeduplicationStrategy :: + StreamingGlobalLimitStrategy(outputMode) :: Nil } private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) @@ -76,6 +78,7 @@ class IncrementalExecution( case ts @ CurrentBatchTimestamp(timestamp, _, _) => logInfo(s"Current batch timestamp = $timestamp") ts.toLiteral + case e: ExpressionWithRandomSeed => e.withNewSeed(Utils.random.nextLong()) } } @@ -99,19 +102,21 @@ class IncrementalExecution( val state = new Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, None, None, + case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, UnaryExecNode(agg, - StateStoreRestoreExec(_, None, child))) => + StateStoreRestoreExec(_, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( keys, Some(aggStateInfo), + stateFormatVersion, child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => @@ -134,8 +139,12 @@ class IncrementalExecution( stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs)) - ) + Some(offsetSeqMetadata.batchWatermarkMs))) + + case l: StreamingGlobalLimitExec => + l.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + outputMode = Some(outputMode)) } } @@ -143,4 +152,14 @@ class IncrementalExecution( /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } + + /** + * Should the MicroBatchExecution run another batch based on this execution and the current + * updated metadata. 
+ */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + executedPlan.collect { + case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata) + }.exists(_ == true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index 66b11ecddf233..8709822acff12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.streaming +import java.text.SimpleDateFormat + import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.{Source => CodahaleSource} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.streaming.StreamingQueryProgress /** @@ -39,6 +42,23 @@ class MetricsReporter( registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0) registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L) + private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 + timestampFormat.setTimeZone(DateTimeUtils.getTimeZone("UTC")) + + registerGauge("eventTime-watermark", + progress => convertStringDateToMillis(progress.eventTime.get("watermark")), 0L) + + registerGauge("states-rowsTotal", _.stateOperators.map(_.numRowsTotal).sum, 0L) + registerGauge("states-usedBytes", _.stateOperators.map(_.memoryUsedBytes).sum, 0L) + + private def convertStringDateToMillis(isoUtcDateStr: String) = { + if (isoUtcDateStr != null) { + timestampFormat.parse(isoUtcDateStr).getTime + } else { + 0L + } + } + private def registerGauge[T]( name: String, f: StreamingQueryProgress => T, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6e231970f4a22..b1cafd67820c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming -import java.util.Optional - import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} @@ -28,10 +26,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow +import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import 
org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -52,8 +49,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = - MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val readSupportToDataSourceMap = + MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -61,6 +58,8 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } + private var watermarkTracker: WatermarkTracker = _ + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in QueryExecutionThread " + @@ -90,20 +89,19 @@ class MicroBatchExecution( StreamingExecutionRelation(source, output)(sparkSession) }) case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val reader = dataSourceV2.createMicroBatchReader( - Optional.empty(), // user specified schema + val readSupport = dataSourceV2.createMicroBatchReadSupport( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReader [$reader] from " + + readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options + logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(reader, output)(sparkSession) + StreamingExecutionRelation(readSupport, output)(sparkSession) }) case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { @@ -124,44 +122,94 @@ class MicroBatchExecution( _logicalPlan } + /** + * Signifies whether current batch (i.e. for the batch `currentBatchId`) has been constructed + * (i.e. written to the offsetLog) and is ready for execution. + */ + private var isCurrentBatchConstructed = false + + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) + queryExecutionThread.interrupt() + queryExecutionThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) + } + logInfo(s"Query $prettyIdString was stopped") + } + /** * Repeatedly attempts to run batches as data arrives. 
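The loop described above hands the "run one trigger, then decide whether to keep going" cycle to a trigger executor. A rough standalone model of a processing-time style executor (an assumed shape, not Spark's ProcessingTimeExecutor):

object TriggerLoopSketch {
  // Repeatedly runs batchRunner at roughly fixed intervals until it returns false,
  // mirroring triggerExecutor.execute(() => { ...; isActive }).
  def execute(intervalMs: Long)(batchRunner: () => Boolean): Unit = {
    var continue = true
    while (continue) {
      val start = System.currentTimeMillis()
      continue = batchRunner()
      val elapsed = System.currentTimeMillis() - start
      if (continue && elapsed < intervalMs) Thread.sleep(intervalMs - elapsed)
    }
  }

  def main(args: Array[String]): Unit = {
    var remainingBatches = 3
    execute(intervalMs = 10L) { () =>
      println(s"running batch, $remainingBatches left")
      remainingBatches -= 1
      remainingBatches > 0 // plays the role of isActive
    }
  }
}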
*/ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - triggerExecutor.execute(() => { - startTrigger() + val noDataBatchesEnabled = + sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled + + triggerExecutor.execute(() => { if (isActive) { + var currentBatchHasNewData = false // Whether the current batch had new data + + startTrigger() + reportTimeTaken("triggerExecution") { + // We'll do this initialization only once every start / restart if (currentBatchId < 0) { - // We'll do this initialization only once populateStartOffsets(sparkSessionForStream) - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() + logInfo(s"Stream started from $committedOffsets") } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") + + // Set this before calling constructNextBatch() so any Spark jobs executed by sources + // while getting new data have the correct description + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + + // Try to construct the next batch. This will return true only if the next batch is + // ready and runnable. Note that the current batch may be runnable even without + // new data to process as `constructNextBatch` may decide to run a batch for + // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data + // is available or not. + if (!isCurrentBatchConstructed) { + isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) + } + + // Record the trigger offset range for progress reporting *before* processing the batch + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) + + // Remember whether the current batch has data or not. This will be required later + // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed + // to false as the batch would have already processed the available data. + currentBatchHasNewData = isNewDataAvailable + + currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) + if (isCurrentBatchConstructed) { + if (currentBatchHasNewData) updateStatusMessage("Processing new data") + else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) + } else { + updateStatusMessage("Waiting for data to arrive") } } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) - if (dataAvailable) { - // Update committed offsets. - commitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data + + finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + + // Signal waiting threads. Note this must be after finishTrigger() to ensure all + // activities (progress generation, etc.) have completed before signaling. + withProgressLocked { awaitProgressLockCondition.signalAll() } + + // If the current batch has been executed, then increment the batch id and reset flag. 
+ // Otherwise, there was no data to execute the batch and sleep for some time + if (isCurrentBatchConstructed) { currentBatchId += 1 - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) - } + isCurrentBatchConstructed = false + } else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -196,6 +244,7 @@ class MicroBatchExecution( /* First assume that we are re-executing the latest known batch * in the offset log */ currentBatchId = latestBatchId + isCurrentBatchConstructed = true availableOffsets = nextOffsets.toStreamProgress(sources) /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. */ @@ -211,13 +260,15 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) + watermarkTracker.setWatermark(metadata.batchWatermarkMs) } /* identify the current batch id: if commit log indicates we successfully processed the * latest batch id in the offset log, then we can safely move to the next batch * i.e., committedBatchId + 1 */ commitLog.getLatest() match { - case Some((latestCommittedBatchId, _)) => + case Some((latestCommittedBatchId, commitMetadata)) => if (latestBatchId == latestCommittedBatchId) { /* The last batch was successfully committed, so we can safely process a * new next batch but first: @@ -233,9 +284,10 @@ class MicroBatchExecution( // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 + isCurrentBatchConstructed = false committedOffsets ++= availableOffsets - // Construct a new batch be recomputing availableOffsets - constructNextBatch() + watermarkTracker.setWatermark( + math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs)) } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -243,19 +295,19 @@ class MicroBatchExecution( } case None => logInfo("no commit log present") } - logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + logInfo(s"Resuming at batch $currentBatchId with committed offsets " + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 - constructNextBatch() + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) } } /** * Returns true if there is any new data available to be processed. */ - private def dataAvailable: Boolean = { + private def isNewDataAvailable: Boolean = { availableOffsets.exists { case (source, available) => committedOffsets @@ -266,93 +318,65 @@ class MicroBatchExecution( } /** - * Queries all of the sources to see if any new data is available. When there is new data the - * batchId counter is incremented and a new log entry is written with the newest offsets. + * Attempts to construct a batch according to: + * - Availability of new data + * - Need for timeouts and state cleanups in stateful operators + * + * Returns true only if the next batch should be executed. 
+ * + * Here is the high-level logic on how this constructs the next batch. + * - Check each source whether new data is available + * - Updated the query's metadata and check using the last execution whether there is any need + * to run another batch (for state clean up, etc.) + * - If either of the above is true, then construct the next batch by committing to the offset + * log that range of offsets that the next batch will process. */ - private def constructNextBatch(): Unit = { - // Check to see what new data is available. - val hasNewData = { - awaitProgressLock.lock() - try { - // Generate a map from each unique source to the next available offset. - val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { - case s: Source => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } - case s: MicroBatchReader => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) - }.toMap - availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) - - if (dataAvailable) { - true - } else { - noNewData = true - false + private def constructNextBatch(noDataBatchesEnabled: Boolean): Boolean = withProgressLocked { + if (isCurrentBatchConstructed) return true + + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) } - } finally { - awaitProgressLock.unlock() - } - } - if (hasNewData) { - var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermarks if we find any in the plan. - if (lastExecution != null) { - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => e - }.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = watermarkMsMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - watermarkMsMap.put(index, newWatermarkMs) - } - - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!watermarkMsMap.isDefinedAt(index)) { - watermarkMsMap.put(index, 0) - } + case s: RateControlMicroBatchReadSupport => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("latestOffset") { + val startOffset = availableOffsets + .get(s).map(off => s.deserializeOffset(off.json)) + .getOrElse(s.initialOffset()) + (s, Option(s.latestOffset(startOffset))) } - - // Update the global watermark to the minimum of all watermark nodes. - // This is the safest option, because only the global watermark is fault-tolerant. 
Making - // it the minimum of all individual watermarks guarantees it will never advance past where - // any individual watermark operator would be if it were in a plan by itself. - if(!watermarkMsMap.isEmpty) { - val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 - if (newWatermarkMs > batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - batchWatermarkMs = newWatermarkMs - } else { - logDebug( - s"Event time didn't move: $newWatermarkMs < " + - s"$batchWatermarkMs") - } + case s: MicroBatchReadSupport => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("latestOffset") { + (s, Option(s.latestOffset())) } - } - offsetSeqMetadata = offsetSeqMetadata.copy( - batchWatermarkMs = batchWatermarkMs, - batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds - + }.toMap + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) + + // Update the query metadata + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = watermarkTracker.currentWatermark, + batchTimestampMs = triggerClock.getTimeMillis()) + + // Check whether next batch should be constructed + val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled && + Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) + val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + logTrace( + s"noDataBatchesEnabled = $noDataBatchesEnabled, " + + s"lastExecutionRequiresAnotherBatch = $lastExecutionRequiresAnotherBatch, " + + s"isNewDataAvailable = $isNewDataAvailable, " + + s"shouldConstructNextBatch = $shouldConstructNextBatch") + + if (shouldConstructNextBatch) { + // Commit the next batch offset range to the offset log updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { - assert(offsetLog.add( - currentBatchId, + assert(offsetLog.add(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId. " + @@ -369,11 +393,14 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (reader: MicroBatchReader, off) => - reader.commit(reader.deserializeOffset(off.json)) + case (readSupport: MicroBatchReadSupport, off) => + readSupport.commit(readSupport.deserializeOffset(off.json)) + case (src, _) => + throw new IllegalArgumentException( + s"Unknown source is found at constructNextBatch: $src") } } else { - throw new IllegalStateException(s"batch $currentBatchId doesn't exist") + throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") } } @@ -384,15 +411,12 @@ class MicroBatchExecution( commitLog.purge(currentBatchId - minLogEntriesToMaintain) } } + noNewData = false } else { - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() - } + noNewData = true + awaitProgressLockCondition.signalAll() } + shouldConstructNextBatch } /** @@ -400,6 +424,8 @@ class MicroBatchExecution( * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { + logDebug(s"Running batch $currentBatchId") + // Request unprocessed data from all sources. 
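The reworked `constructNextBatch()` above no longer requires new data to run a batch: it also constructs one when the last execution signals that stateful operators need another pass (timeouts, state cleanup), provided no-data batches are enabled. A condensed sketch of just that decision, with the inputs passed in as plain booleans rather than derived from the offsets map and the executed plan:

```scala
object BatchDecisionSketch {
  // A batch runs either because sources reported new offsets, or because a
  // stateful operator needs a "no-data" batch and that feature is enabled.
  def shouldConstructNextBatch(
      isNewDataAvailable: Boolean,
      noDataBatchesEnabled: Boolean,
      lastExecRequiresAnotherBatch: Boolean): Boolean = {
    isNewDataAvailable || (noDataBatchesEnabled && lastExecRequiresAnotherBatch)
  }
}
```

Only when this predicate is true does the loop above commit the batch's offset range to the offset log and mark the batch as constructed; otherwise it reports no new data and signals any waiting threads.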
newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -412,30 +438,34 @@ class MicroBatchExecution( s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - case (reader: MicroBatchReader, available) - if committedOffsets.get(reader).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) - val availableV2: OffsetV2 = available match { - case v1: SerializedOffset => reader.deserializeOffset(v1.json) + + // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but + // to be compatible with streaming source v1, we return a logical plan as a new batch here. + case (readSupport: MicroBatchReadSupport, available) + if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(readSupport).map { + off => readSupport.deserializeOffset(off.json) + } + val endOffset: OffsetV2 = available match { + case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) - logDebug(s"Retrieving data from $reader: $current -> $availableV2") + val startOffset = current.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) + logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") - val (source, options) = reader match { + val (source, options) = readSupport match { // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` // implementation. We provide a fake one here for explain. case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] // Provide a fake value here just in case something went wrong, e.g. the reader gives // a wrong `equals` implementation. - case _ => readerToDataSourceMap.getOrElse(reader, { + case _ => readSupportToDataSourceMap.getOrElse(readSupport, { FakeDataSourceV2 -> Map.empty[String, String] }) } - Some(reader -> StreamingDataSourceV2Relation( - reader.readSchema().toAttributes, source, options, reader)) + Some(readSupport -> StreamingDataSourceV2Relation( + readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) case _ => None } } @@ -469,18 +499,13 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamWriteSupport => - val writer = s.createStreamWriter( + case s: StreamingWriteSupportProvider => + val writer = s.createStreamingWriteSupport( s"$runId", newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - if (writer.isInstanceOf[SupportsWriteInternalRow]) { - WriteToDataSourceV2( - new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) - } else { - WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) - } + WriteToDataSourceV2(new MicroBatchWritSupport(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -506,24 +531,23 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamWriteSupport => + case _: StreamingWriteSupportProvider => // This doesn't accumulate any data - it just forces execution of the microbatch writer. 
nextBatch.collect() } } } - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() + withProgressLocked { + watermarkTracker.updateWatermark(lastExecution.executedPlan) + commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) + committedOffsets ++= availableOffsets } + logDebug(s"Completed batch ${currentBatchId}") } /** Execute a function while locking the stream from making an progress */ - private[sql] def withProgressLocked(f: => Unit): Unit = { + private[sql] def withProgressLocked[T](f: => T): T = { awaitProgressLock.lock() try { f @@ -531,10 +555,6 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } - - private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { - Optional.ofNullable(scalaOption.orNull) - } } object MicroBatchExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java index 80aa5505db991..43ad4b3384ec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java @@ -19,8 +19,8 @@ /** * This is an internal, deprecated interface. New source implementations should use the - * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported - * in the long term. + * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be + * supported in the long term. * * This class will be removed in a future release. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73945b39b8967..73cf355dbe758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} +import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -39,7 +40,9 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * cannot be serialized). */ def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { - assert(sources.size == offsets.size) + assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. 
" + + s"Cannot continue.") new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } @@ -84,7 +87,27 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) - private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + private val relevantSQLConfs = Seq( + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + + /** + * Default values of relevant configurations that are used for backward compatibility. + * As new configurations are added to the metadata, existing checkpoints may not have those + * confs. The values in this list ensures that the confs without recovered values are + * set to a default value that ensure the same behavior of the streaming query as it was before + * the restart. + * + * Note, that this is optional; set values here if you *have* to override existing session conf + * with a specific default value for ensuring same behavior of the query as before. + */ + private val relevantSQLConfDefaultValues = Map[String, String]( + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> + FlatMapGroupsWithStateExecHelper.legacyVersion.toString, + STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> + StreamingAggregationStateManager.legacyVersion.toString + ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) @@ -113,8 +136,22 @@ object OffsetSeqMetadata extends Logging { case None => // For backward compatibility, if a config was not recorded in the offset log, - // then log it, and let the existing conf value in SparkSession prevail. - logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + // then either inject a default value (if specified in `relevantSQLConfDefaultValues`) or + // let the existing conf value in SparkSession prevail. + relevantSQLConfDefaultValues.get(confKey) match { + + case Some(defaultValue) => + sessionConf.set(confKey, defaultValue) + logWarning(s"Conf '$confKey' was not found in the offset log, " + + s"using default value '$defaultValue'") + + case None => + val valueStr = sessionConf.getOption(confKey).map { v => + s" Using existing session conf value '$v'." + }.getOrElse { " No value set in session conf." } + logWarning(s"Conf '$confKey' was not found in the offset log. 
$valueStr") + + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d1e5be9c12762..417b6b39366ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -22,12 +22,20 @@ import java.util.{Date, UUID} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.control.NonFatal + +import org.json4s.jackson.JsonMethods.parse import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWritSupport +import org.apache.spark.sql.sources.v2.CustomMetrics +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, SupportsCustomReaderMetrics} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWriteSupport, SupportsCustomWriterMetrics} import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -54,8 +62,6 @@ trait ProgressReporter extends Logging { protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution protected def newData: Map[BaseStreamingSource, LogicalPlan] - protected def availableOffsets: StreamProgress - protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata @@ -66,8 +72,11 @@ trait ProgressReporter extends Logging { // Local timestamps and counters. private var currentTriggerStartTimestamp = -1L private var currentTriggerEndTimestamp = -1L + private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _ + private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _ // TODO: Restore this from the checkpoint when possible. private var lastTriggerStartTimestamp = -1L + private val currentDurationsMs = new mutable.HashMap[String, Long]() /** Flag that signals whether any error with input metrics have already been logged */ @@ -112,9 +121,20 @@ trait ProgressReporter extends Logging { lastTriggerStartTimestamp = currentTriggerStartTimestamp currentTriggerStartTimestamp = triggerClock.getTimeMillis() currentStatus = currentStatus.copy(isTriggerActive = true) + currentTriggerStartOffsets = null + currentTriggerEndOffsets = null currentDurationsMs.clear() } + /** + * Record the offsets range this trigger will process. Call this before updating + * `committedOffsets` in `StreamExecution` to make sure that the correct range is recorded. 
+ */ + protected def recordTriggerOffsets(from: StreamProgress, to: StreamProgress): Unit = { + currentTriggerStartOffsets = from.mapValues(_.json) + currentTriggerEndOffsets = to.mapValues(_.json) + } + private def updateProgress(newProgress: StreamingQueryProgress): Unit = { progressBuffer.synchronized { progressBuffer += newProgress @@ -128,6 +148,7 @@ trait ProgressReporter extends Logging { /** Finalizes the query progress and adds it to list of recent status updates. */ protected def finishTrigger(hasNewData: Boolean): Unit = { + assert(currentTriggerStartOffsets != null && currentTriggerEndOffsets != null) currentTriggerEndTimestamp = triggerClock.getTimeMillis() val executionStats = extractExecutionStats(hasNewData) @@ -141,18 +162,51 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") - val sourceProgress = sources.map { source => + // extracts and validates custom metrics from readers and writers + def extractMetrics( + getMetrics: () => Option[CustomMetrics], + onInvalidMetrics: (Exception) => Unit): Option[String] = { + try { + getMetrics().map(m => { + val json = m.json() + parse(json) + json + }) + } catch { + case ex: Exception if NonFatal(ex) => + onInvalidMetrics(ex) + None + } + } + + val sourceProgress = sources.distinct.map { source => + val customReaderMetrics = source match { + case s: SupportsCustomReaderMetrics => + extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) + + case _ => None + } + val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, - startOffset = committedOffsets.get(source).map(_.json).orNull, - endOffset = availableOffsets.get(source).map(_.json).orNull, + startOffset = currentTriggerStartOffsets.get(source).orNull, + endOffset = currentTriggerEndOffsets.get(source).orNull, numInputRows = numRecords, inputRowsPerSecond = numRecords / inputTimeSec, - processedRowsPerSecond = numRecords / processingTimeSec + processedRowsPerSecond = numRecords / processingTimeSec, + customReaderMetrics.orNull ) } - val sinkProgress = new SinkProgress(sink.toString) + + val customWriterMetrics = extractWriteSupport() match { + case Some(s: SupportsCustomWriterMetrics) => + extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) + + case _ => None + } + + val sinkProgress = new SinkProgress(sink.toString, customWriterMetrics.orNull) val newProgress = new StreamingQueryProgress( id = id, @@ -181,6 +235,18 @@ trait ProgressReporter extends Logging { currentStatus = currentStatus.copy(isTriggerActive = false) } + /** Extract writer from the executed query plan. */ + private def extractWriteSupport(): Option[StreamingWriteSupport] = { + if (lastExecution == null) return None + lastExecution.executedPlan.collect { + case p if p.isInstanceOf[WriteToDataSourceV2Exec] => + p.asInstanceOf[WriteToDataSourceV2Exec].writeSupport + }.headOption match { + case Some(w: MicroBatchWritSupport) => Some(w.writeSupport) + case _ => None + } + } + /** Extract statistics about stateful operators from the executed query plan. */ private def extractStateOperatorMetrics(hasNewData: Boolean): Seq[StateOperatorProgress] = { if (lastExecution == null) return Nil @@ -207,62 +273,126 @@ trait ProgressReporter extends Logging { return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp) } - // We want to associate execution plan leaves to sources that generate them, so that we match - // the their metrics (e.g. numOutputRows) to the sources. 
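The new `extractMetrics` helper above guards the progress event against malformed custom metrics by parsing the JSON once before attaching it. A trimmed, standalone version of that guard, assuming json4s is on the classpath as in the surrounding code:

```scala
import scala.util.control.NonFatal

import org.json4s.jackson.JsonMethods.parse

object CustomMetricsSketch {
  // Validate a reader/writer's custom-metrics JSON before attaching it to a
  // progress event: parse it once to check well-formedness and return the raw
  // string, falling back to None (after calling the error hook) if it fails.
  def extract(getJson: () => Option[String], onInvalid: Exception => Unit): Option[String] = {
    try {
      getJson().map { json =>
        parse(json) // throws if the string is not valid JSON
        json
      }
    } catch {
      case ex: Exception if NonFatal(ex) =>
        onInvalid(ex)
        None
    }
  }
}
```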
To do this we do the following. - // Consider the translation from the streaming logical plan to the final executed plan. - // - // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan - // - // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan - // - Each logical plan leaf will be associated with a single streaming source. - // - There can be multiple logical plan leaves associated with a streaming source. - // - There can be leaves not associated with any streaming source, because they were - // generated from a batch source (e.g. stream-batch joins) - // - // 2. Assuming that the executed plan has same number of leaves in the same order as that of - // the trigger logical plan, we associate executed plan leaves with corresponding - // streaming sources. - // - // 3. For each source, we sum the metrics of the associated execution plan leaves. - // - val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => - logicalPlan.collectLeaves().map { leaf => leaf -> source } + val numInputRows = extractSourceToNumInputRows() + + val eventTimeStats = lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + val stats = e.eventTimeStats.value + Map( + "max" -> stats.max, + "min" -> stats.min, + "avg" -> stats.avg.toLong).mapValues(formatTimestamp) + }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp + + ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } + + /** Extract number of input sources for each streaming source in plan */ + private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + + import java.util.IdentityHashMap + import scala.collection.JavaConverters._ + + def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + } + + val onlyDataSourceV2Sources = { + // Check whether the streaming query's logical plan has only V2 data sources + val allStreamingLeaves = + logicalPlan.collect { case s: StreamingExecutionRelation => s } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } } - val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming - val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[BaseStreamingSource, Long] = + + if (onlyDataSourceV2Sources) { + // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data + // from a V2 source and has a direct reference to the V2 source that generated it. Each + // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, + // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as + // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or + // even multiple times) points and considering it twice will lead to double counting. We + // can't dedup them using their hashcode either because two different instances of + // DataSourceV2ScanExec can have the same hashcode but account for separate sets of + // records read, and deduping them to consider only one of them would be undercounting the + // records read. Therefore the right way to do this is to consider the unique instances of + // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. + // Hence we calculate in the following way. + // + // 1. 
Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. + // + // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. + // + // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with + // self-unions or self-joins). Add up the number of rows for each unique source. + val uniqueStreamingExecLeavesMap = + new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() + + lastExecution.executedPlan.collectLeaves().foreach { + case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => + uniqueStreamingExecLeavesMap.put(s, s) + case _ => + } + + val sourceToInputRowsTuples = + uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource] + source -> numRows + }.toSeq + logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) + sumRows(sourceToInputRowsTuples) + } else { + + // Since V1 source do not generate execution plan leaves that directly link with source that + // generated it, we can only do a best-effort association between execution plan leaves to the + // sources. This is known to fail in a few cases, see SPARK-24050. + // + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. 
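The V2 path above dedupes scan nodes by object identity before summing their `numOutputRows`, because the same node can appear at several points in the plan while two distinct nodes may compare `equal`. A compact sketch of that counting step with illustrative stand-in classes:

```scala
import java.util.IdentityHashMap

import scala.collection.JavaConverters._

object InputRowCountSketch {
  // Minimal stand-ins: a "scan" holds a reference to the source that created it
  // plus its numOutputRows metric (illustrative names, not the real classes).
  final class Source(val name: String)
  final class Scan(val source: Source, val numOutputRows: Long)

  /** Count input rows per source, counting each Scan instance exactly once. */
  def rowsPerSource(leaves: Seq[Scan]): Map[Source, Long] = {
    // Dedup by identity: equals/hashCode cannot be trusted to distinguish
    // different instances that read different data.
    val unique = new IdentityHashMap[Scan, Scan]()
    leaves.foreach(s => unique.put(s, s))

    unique.values.asScala.toSeq
      .map(s => s.source -> s.numOutputRows)
      .groupBy(_._1)
      .mapValues(_.map(_._2).sum) // self-unions/joins: sum per unique source
      .toMap
  }
}
```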
+ // + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming + val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } } - val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) source -> numRows } - sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + sumRows(sourceToInputRowsTuples) } else { if (!metricWarningLogged) { def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( "Could not report metrics as number leaves in trigger logical plan did not match that" + - s" of the execution plan:\n" + - s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + - s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") metricWarningLogged = true } Map.empty } - - val eventTimeStats = lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - val stats = e.eventTimeStats.value - Map( - "max" -> stats.max, - "min" -> stats.min, - "avg" -> stats.avg.toLong).mapValues(formatTimestamp) - }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp - - ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } } /** Records the duration of running `body` for the next query progress update. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala new file mode 100644 index 0000000000000..1be071614d92e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.types.StructType + +/** + * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to + * carry schema and offsets for streaming data sources. + */ +class SimpleStreamingScanConfigBuilder( + schema: StructType, + start: Offset, + end: Option[Offset] = None) + extends ScanConfigBuilder { + + override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end) +} + +case class SimpleStreamingScanConfig( + readSchema: StructType, + start: Offset, + end: Option[Offset]) + extends ScanConfig diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3fc8c7887896a..a39bb715c9913 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -378,29 +378,11 @@ abstract class StreamExecution( } } - /** - * Signals to the thread executing micro-batches that it should stop running after the next - * batch. This method blocks until the thread stops running. - */ - override def stop(): Unit = { - // Set the state to TERMINATED so that the batching thread knows that it was interrupted - // intentionally - state.set(TERMINATED) - if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) - queryExecutionThread.interrupt() - queryExecutionThread.join() - // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak - sparkSession.sparkContext.cancelJobGroup(runId.toString) - } - logInfo(s"Query $prettyIdString was stopped") - } - /** * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset, timeoutMs: Long): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets @@ -416,7 +398,7 @@ abstract class StreamExecution( while (notDone) { awaitProgressLock.lock() try { - awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS) + awaitProgressLockCondition.await(timeoutMs, TimeUnit.MILLISECONDS) if (streamDeathCause != null) { throw streamDeathCause } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala new file mode 100644 index 0000000000000..bf4af60c8cf03 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.state.StateStoreOps +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType} +import org.apache.spark.util.CompletionIterator + +/** + * A physical operator for executing a streaming limit, which makes sure no more than streamLimit + * rows are returned. This operator is meant for streams in Append mode only. + */ +case class StreamingGlobalLimitExec( + streamLimit: Long, + child: SparkPlan, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None) + extends UnaryExecNode with StateStoreWriter { + + private val keySchema = StructType(Array(StructField("key", NullType))) + private val valueSchema = StructType(Array(StructField("value", LongType))) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append, + "StreamingGlobalLimitExec is only valid for streams in Append output mode") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + keySchema, + valueSchema, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null))) + val numOutputRows = longMetric("numOutputRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val updatesStartTimeNs = System.nanoTime + + val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L) + var cumulativeRowCount = preBatchRowCount + + val result = iter.filter { r => + val x = cumulativeRowCount < streamLimit + if (x) { + cumulativeRowCount += 1 + } + x + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + if (cumulativeRowCount > preBatchRowCount) { + numUpdatedStateRows += 1 + numOutputRows += cumulativeRowCount - preBatchRowCount + store.put(key, getValueRow(cumulativeRowCount)) + } + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + 
override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil + + private def getValueRow(value: Long): UnsafeRow = { + UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index f02d3a2c3733f..4b696dfa57359 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -66,6 +66,7 @@ case class StreamingExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString @@ -82,7 +83,7 @@ case class StreamingExecutionRelation( // We have to pack in the V1 data source as a shim, for the case when a source implements // continuous processing (which is always V2) but only has V1 microbatch support. We don't -// know at read time whether the query is conntinuous or not, so we need to be able to +// know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** * Used to link a [[DataSourceV2]] into a streaming @@ -97,6 +98,7 @@ case class StreamingRelationV2( output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = sourceName @@ -111,11 +113,12 @@ case class StreamingRelationV2( * Used to link a [[DataSourceV2]] into a continuous processing execution. 
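The `StreamingGlobalLimitExec` operator added above enforces a global limit across micro-batches by persisting a cumulative row count in a state store keyed by a single null key. A side-effect-free sketch of one batch of that logic, with the persisted count passed in and returned instead of stored:

```scala
object StreamingLimitSketch {
  /**
   * One micro-batch of a global streaming limit: keep rows only while the
   * running total (carried over from previous batches) is below the limit,
   * and return the new total to persist for the next batch.
   */
  def applyLimit[T](rows: Iterator[T], limit: Long, preBatchCount: Long): (List[T], Long) = {
    var cumulative = preBatchCount
    val kept = rows.filter { _ =>
      val admit = cumulative < limit
      if (admit) cumulative += 1
      admit
    }.toList
    (kept, cumulative)
  }
}
```

For example, with `limit = 5` and a carried-over count of 3, a batch of four rows keeps only the first two and returns 5 as the new count to persist.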
*/ case class ContinuousExecutionRelation( - source: ContinuousReadSupport, + source: ContinuousReadSupportProvider, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index fa7c8ee906ecd..50cf971e4ec3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output @@ -187,6 +187,17 @@ case class StreamingSymmetricHashJoinExec( s"${getClass.getSimpleName} should not take $x as the JoinType") } + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + val watermarkUsedForStateCleanup = + stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty + + // Latest watermark value is more than that used in this previous executed plan + val watermarkHasChanged = + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + + watermarkUsedForStateCleanup && watermarkHasChanged + } + protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) @@ -319,8 +330,7 @@ case class StreamingSymmetricHashJoinExec( // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal // needs to be done greedily by immediately consuming the returned iterator. 
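The `shouldRunAnotherBatch` override added to `StreamingSymmetricHashJoinExec` above requests a no-data batch only when the join's state cleanup actually depends on the watermark and the watermark has advanced since the last executed plan. A simplified restatement of that predicate:

```scala
object JoinNoDataBatchSketch {
  /** Simplified form of the predicate in the hunk above. */
  def shouldRunAnotherBatch(
      hasStateWatermarkPredicate: Boolean, // state cleanup depends on the watermark
      planWatermarkMs: Option[Long],       // watermark the last plan was executed with
      newWatermarkMs: Long): Boolean = {   // latest global watermark
    val watermarkHasChanged = planWatermarkMs.exists(newWatermarkMs > _)
    hasStateWatermarkPredicate && watermarkHasChanged
  }
}
```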
val cleanupIter = joinType match { - case Inner => - leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case LeftOuter => rightSideJoiner.removeOldState() case RightOuter => leftSideJoiner.removeOldState() case _ => throwBadJoinTypeException() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 4aba76cad367e..2d4c3c10e6445 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -144,7 +144,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { // Join keys of both sides generate rows of the same fields, that is, same sequence of data - // types. If one side (say left side) has a column (say timestmap) that has a watermark on it, + // types. If one side (say left side) has a column (say timestamp) that has a watermark on it, // then it will never consider joining keys that are < state key watermark (i.e. event time // watermark). On the other side (i.e. right side), even if there is no watermark defined, // there has to be an equivalent column (i.e., timestamp). And any right side data that has the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala new file mode 100644 index 0000000000000..7b30db44a2090 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf + +/** + * Policy to define how to choose a new global watermark value if there are + * multiple watermark operators in a streaming query. 
+ */ +sealed trait MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long +} + +object MultipleWatermarkPolicy { + val DEFAULT_POLICY_NAME = "min" + + def apply(policyName: String): MultipleWatermarkPolicy = { + policyName.toLowerCase match { + case DEFAULT_POLICY_NAME => MinWatermark + case "max" => MaxWatermark + case _ => + throw new IllegalArgumentException(s"Could not recognize watermark policy '$policyName'") + } + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. + * Note that this is the safe (hence default) policy as the global watermark will advance + * only if all the individual operator watermarks have advanced. In other words, in a + * streaming query with multiple input streams and watermarks defined on all of them, + * the global watermark will advance as slowly as the slowest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most conservative one. + */ +case object MinWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.min + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. So the + * global watermark will advance if any of the individual operator watermarks has advanced. + * In other words, in a streaming query with multiple input streams and watermarks defined on all + * of them, the global watermark will advance as fast as the fastest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most aggressive one and + * may lead to unexpected behavior if the data of the slow stream is delayed. + */ +case object MaxWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.max + } +} + +/** Tracks the watermark value of a streaming query based on a given `policy` */ +case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { + private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private var globalWatermarkMs: Long = 0 + + def setWatermark(newWatermarkMs: Long): Unit = synchronized { + globalWatermarkMs = newWatermarkMs + } + + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { + val watermarkOperators = executedPlan.collect { + case e: EventTimeWatermarkExec => e + } + if (watermarkOperators.isEmpty) return + + watermarkOperators.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = operatorToWatermarkMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + operatorToWatermarkMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!operatorToWatermarkMap.isDefinedAt(index)) { + operatorToWatermarkMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. 
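The two policies above differ only in how they fold the per-operator watermarks into one global value. A self-contained sketch of the policy selection and the min/max choice (names here are illustrative, not the real classes):

```scala
object WatermarkPolicySketch {
  sealed trait Policy { def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long }

  // "min" (the default) lets the global watermark trail the slowest input, the
  // safe choice when state cleanup or late-data dropping depends on it.
  case object MinPolicy extends Policy {
    def chooseGlobalWatermark(ws: Seq[Long]): Long = { require(ws.nonEmpty); ws.min }
  }

  // "max" advances with the fastest input and may drop data from slower streams.
  case object MaxPolicy extends Policy {
    def chooseGlobalWatermark(ws: Seq[Long]): Long = { require(ws.nonEmpty); ws.max }
  }

  def fromName(name: String): Policy = name.toLowerCase match {
    case "min"  => MinPolicy
    case "max"  => MaxPolicy
    case other  => throw new IllegalArgumentException(s"Unknown watermark policy '$other'")
  }
}
```

With operator watermarks of `Seq(10000L, 60000L)`, the default policy yields 10000 while `"max"` yields 60000.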
+ val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq) + if (chosenGlobalWatermark > globalWatermarkMs) { + logInfo(s"Updating event-time watermark from $globalWatermarkMs to $chosenGlobalWatermark ms") + globalWatermarkMs = chosenGlobalWatermark + } else { + logDebug(s"Event time watermark didn't move: $chosenGlobalWatermark < $globalWatermarkMs") + } + } + + def currentWatermark: Long = synchronized { globalWatermarkMs } +} + +object WatermarkTracker { + def apply(conf: RuntimeConfig): WatermarkTracker = { + // If the session has been explicitly configured to use non-default policy then use it, + // otherwise use the default `min` policy as thats the safe thing to do. + // When recovering from a checkpoint location, it is expected that the `conf` will already + // be configured with the value present in the checkpoint. If there is no policy explicitly + // saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced + // through defaults specified in OffsetSeqMetadata.setSessionConf(). + val policyName = conf.get( + SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) + new WatermarkTracker(MultipleWatermarkPolicy(policyName)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index cfba1001c6de0..9c5c16f4f5d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -31,16 +31,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with StreamWriteSupport + with StreamingWriteSupportProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new ConsoleWriter(schema, options) + options: DataSourceOptions): StreamingWriteSupport = { + new ConsoleWriteSupport(schema, options) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala new file mode 100644 index 0000000000000..5f60343bacfaa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark.{HashPartitioner, SparkEnv} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD} + +/** + * Physical plan for coalescing a continuous processing plan. + * + * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1. + */ +case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan { + override def output: Seq[Attribute] = child.output + + override def children: Seq[SparkPlan] = child :: Nil + + override def outputPartitioning: Partitioning = SinglePartition + + override def doExecute(): RDD[InternalRow] = { + assert(numPartitions == 1) + new ContinuousCoalesceRDD( + sparkContext, + numPartitions, + conf.continuousStreamingExecutorQueueSize, + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong, + child.execute()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala new file mode 100644 index 0000000000000..aec756c0eb2a4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.continuous.shuffle._ +import org.apache.spark.util.ThreadUtils + +case class ContinuousCoalesceRDDPartition( + index: Int, + endpointName: String, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + + TaskContext.get().addTaskCompletionListener[Unit] { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } + // This flag will be flipped on the executors to indicate that the threads processing + // partitions of the write-side RDD have been started. These will run indefinitely + // asynchronously as epochs of the coalesce RDD complete on the read side. + private[continuous] var writersInitialized: Boolean = false +} + +/** + * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local + * continuous shuffle, and then reads them in the task thread using `reader`. + */ +class ContinuousCoalesceRDD( + context: SparkContext, + numPartitions: Int, + readerQueueSize: Int, + epochIntervalMs: Long, + prev: RDD[InternalRow]) + extends RDD[InternalRow](context, Nil) { + + // When we support more than 1 target partition, we'll need to figure out how to pass in the + // required partitioner. + private val outputPartitioner = new HashPartitioner(1) + + private val readerEndpointNames = (0 until numPartitions).map { i => + s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}" + } + + override def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousCoalesceRDDPartition( + partIndex, + readerEndpointNames(partIndex), + readerQueueSize, + prev.getNumPartitions, + epochIntervalMs) + }.toArray + } + + private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool( + prev.getNumPartitions, + this.name) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val part = split.asInstanceOf[ContinuousCoalesceRDDPartition] + + if (!part.writersInitialized) { + val rpcEnv = SparkEnv.get.rpcEnv + + // trigger lazy initialization + part.endpoint + val endpointRefs = readerEndpointNames.map { endpointName => + rpcEnv.setupEndpointRef(rpcEnv.address, endpointName) + } + + val runnables = prev.partitions.map { prevSplit => + new Runnable() { + override def run(): Unit = { + TaskContext.setTaskContext(context) + + val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter( + prevSplit.index, outputPartitioner, endpointRefs.toArray) + + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + while (!context.isInterrupted() && !context.isCompleted()) { + writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]]) + // Note that current epoch is a non-inheritable thread local, so each writer thread + // can properly increment its own epoch without affecting the main task thread. 
+ EpochTracker.incrementCurrentEpoch() + } + } + } + } + + context.addTaskCompletionListener[Unit] { ctx => + threadPool.shutdownNow() + } + + part.writersInitialized = true + + runnables.foreach(threadPool.execute) + } + + part.reader.read() + } + + override def clearDependencies(): Unit = { + throw new IllegalStateException("Continuous RDDs cannot be checkpointed") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala new file mode 100644 index 0000000000000..b68f67e0b22d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.NextIterator + +class ContinuousDataSourceRDDPartition( + val index: Int, + val inputPartition: InputPartition) + extends Partition with Serializable { + + // This is semantically a lazy val - it's initialized once the first time a call to + // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across + // all compute() calls for a partition. This ensures that one compute() picks up where the + // previous one ended. + // We don't make it actually a lazy val because it needs input which isn't available here. + // This will only be initialized on the executors. + private[continuous] var queueReader: ContinuousQueuedDataReader = _ +} + +/** + * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]] + * to read from the remote source, and polls that queue for incoming rows. + * + * Note that continuous processing calls compute() multiple times, and the same + * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split. 
+ */ +class ContinuousDataSourceRDD( + sc: SparkContext, + dataQueueSize: Int, + epochPollIntervalMs: Long, + private val inputPartitions: Seq[InputPartition], + schema: StructType, + partitionReaderFactory: ContinuousPartitionReaderFactory) + extends RDD[InternalRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + inputPartitions.zipWithIndex.map { + case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) + }.toArray + } + + private def castPartition(split: Partition): ContinuousDataSourceRDDPartition = split match { + case p: ContinuousDataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a ContinuousDataSourceRDDPartition: $split") + } + + /** + * Initialize the shared reader for this partition if needed, then read rows from it until + * it returns null to signal the end of the epoch. + */ + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + + val readerForPartition = { + val partition = castPartition(split) + if (partition.queueReader == null) { + val partitionReader = partitionReaderFactory.createReader( + partition.inputPartition) + partition.queueReader = new ContinuousQueuedDataReader( + partition.index, partitionReader, schema, context, dataQueueSize, epochPollIntervalMs) + } + + partition.queueReader + } + + new NextIterator[InternalRow] { + override def getNext(): InternalRow = { + readerForPartition.next() match { + case null => + finished = true + null + case row => row + } + } + + override def close(): Unit = {} + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + castPartition(split).inputPartition.preferredLocations() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala deleted file mode 100644 index 06754f01657d3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.util.ThreadUtils - -class ContinuousDataSourceRDD( - sc: SparkContext, - sqlContext: SQLContext, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { - - private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize - private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs - - override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - // If attempt number isn't 0, this is a task retry, which we don't support. - if (context.attemptNumber() != 0) { - throw new ContinuousTaskRetryException() - } - - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] - .readerFactory.createDataReader() - - val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) - - // This queue contains two types of messages: - // * (null, null) representing an epoch boundary. - // * (row, off) containing a data row and its corresponding PartitionOffset. 
- val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) - - val epochPollFailed = new AtomicBoolean(false) - val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--$coordinatorId--${context.partitionId()}") - val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) - epochPollExecutor.scheduleWithFixedDelay( - epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - - // Important sequencing - we must get start offset before the data reader thread begins - val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset - - val dataReaderFailed = new AtomicBoolean(false) - val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) - dataReaderThread.setDaemon(true) - dataReaderThread.start() - - context.addTaskCompletionListener(_ => { - dataReaderThread.interrupt() - epochPollExecutor.shutdown() - }) - - val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) - new Iterator[UnsafeRow] { - private val POLL_TIMEOUT_MS = 1000 - - private var currentEntry: (UnsafeRow, PartitionOffset) = _ - private var currentOffset: PartitionOffset = startOffset - private var currentEpoch = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def hasNext(): Boolean = { - while (currentEntry == null) { - if (context.isInterrupted() || context.isCompleted()) { - currentEntry = (null, null) - } - if (dataReaderFailed.get()) { - throw new SparkException("data read failed", dataReaderThread.failureReason) - } - if (epochPollFailed.get()) { - throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) - } - currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) - } - - currentEntry match { - // epoch boundary marker - case (null, null) => - epochEndpoint.send(ReportPartitionOffset( - context.partitionId(), - currentEpoch, - currentOffset)) - currentEpoch += 1 - currentEntry = null - false - // real row - case (_, offset) => - currentOffset = offset - true - } - } - - override def next(): UnsafeRow = { - if (currentEntry == null) throw new NoSuchElementException("No current row was set") - val r = currentEntry._1 - currentEntry = null - r - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() - } -} - -case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset - -class EpochPollRunnable( - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread with Logging { - private[continuous] var failureReason: Throwable = _ - - private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def run(): Unit = { - try { - val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) - for (i <- currentEpoch to newEpoch - 1) { - queue.put((null, null)) - logDebug(s"Sent marker to start epoch ${i + 1}") - } - currentEpoch = newEpoch - } catch { - case t: Throwable => - failureReason = t - failedFlag.set(true) - throw t - } - } -} - -class DataReaderThread( - reader: DataReader[UnsafeRow], - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread( - 
s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { - private[continuous] var failureReason: Throwable = _ - - override def run(): Unit = { - TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) - try { - while (!context.isInterrupted && !context.isCompleted()) { - if (!reader.next()) { - // Check again, since reader.next() might have blocked through an incoming interrupt. - if (!context.isInterrupted && !context.isCompleted()) { - throw new IllegalStateException( - "Continuous reader reported no elements! Reader should have blocked waiting.") - } else { - return - } - } - - queue.put((reader.get().copy(), baseReader.getOffset)) - } - } catch { - case _: InterruptedException if context.isInterrupted() => - // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. - - case t: Throwable => - failureReason = t - failedFlag.set(true) - // Don't rethrow the exception in this thread. It's not needed, and the default Spark - // exception handler will kill the executor. - } finally { - reader.close() - } - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { - reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 951d694355ec5..4ddebb33b79d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,13 +29,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} class ContinuousExecution( @@ -43,7 +42,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamWriteSupport, + sink: StreamingWriteSupportProvider, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ 
-53,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() + @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. @@ -63,7 +62,8 @@ class ContinuousExecution( val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( - source: ContinuousReadSupport, _, extraReaderOptions, output, _) => + source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => + // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? toExecutionRelationMap.getOrElseUpdate(r, { ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) @@ -122,16 +122,7 @@ class ContinuousExecution( s"Batch $latestEpochId was committed without end epoch offsets!") } committedOffsets = nextOffsets.toStreamProgress(sources) - - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets @@ -157,8 +148,7 @@ class ContinuousExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - dataSource.createContinuousReader( - java.util.Optional.empty[StructType](), + dataSource.createContinuousReadSupport( metadataPath, new DataSourceOptions(extraReaderOptions.asJava)) } @@ -169,9 +159,9 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { case ContinuousExecutionRelation(source, options, output) => - val reader = continuousSources(insertedSourceId) + val readSupport = continuousSources(insertedSourceId) insertedSourceId += 1 - val newOutput = reader.readSchema().toAttributes + val newOutput = readSupport.fullSchema().toAttributes assert(output.size == newOutput.size, s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + @@ -179,9 +169,10 @@ class ContinuousExecution( replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) - val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - StreamingDataSourceV2Relation(newOutput, source, options, reader) + val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) + val startOffset = realOffset.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) + StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) } // Rewire the plan to use the new attributes that were returned by the source. 
@@ -194,16 +185,12 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamWriter( + val writer = sink.createStreamingWriteSupport( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) - - val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r - }.head + val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( @@ -217,6 +204,11 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } + val (readSupport, scanConfig) = lastExecution.executedPlan.collect { + case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => + scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig + }.head + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) // Add another random ID on top of the run ID, to distinguish epoch coordinators across @@ -225,18 +217,23 @@ class ContinuousExecution( currentEpochCoordinatorId = epochCoordinatorId sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_INTERVAL_KEY, + trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) && + state.compareAndSet(ACTIVE, RECONFIGURING) + if (shouldReconfigure) { if (queryExecutionThread.isAlive) { queryExecutionThread.interrupt() } @@ -286,10 +283,12 @@ class ContinuousExecution( * Report ending partition offsets for the given reader at the given epoch. 
*/ def addOffset( - epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { + epoch: Long, + readSupport: ContinuousReadSupport, + partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray) val oldOffset = synchronized { offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) offsetLog.get(epoch - 1) @@ -315,9 +314,12 @@ class ContinuousExecution( def commit(epoch: Long): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") + synchronized { + // Record offsets before updating `committedOffsets` + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) if (queryExecutionThread.isAlive) { - commitLog.add(epoch) + commitLog.add(epoch, CommitMetadata()) val offset = continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) @@ -327,9 +329,14 @@ class ContinuousExecution( } } - if (minLogEntriesToMaintain < currentBatchId) { - offsetLog.purge(currentBatchId - minLogEntriesToMaintain) - commitLog.purge(currentBatchId - minLogEntriesToMaintain) + // Since currentBatchId increases independently in cp mode, the current committed epoch may + // be far behind currentBatchId. It is not safe to discard the metadata with thresholdBatchId + // computed based on currentBatchId. As minLogEntriesToMaintain is used to keep the minimum + // number of batches that must be retained and made recoverable, so we should keep the + // specified number of metadata that have been committed. + if (minLogEntriesToMaintain <= epoch) { + offsetLog.purge(epoch + 1 - minLogEntriesToMaintain) + commitLog.purge(epoch + 1 - minLogEntriesToMaintain) } awaitProgressLock.lock() @@ -365,9 +372,26 @@ class ContinuousExecution( } } } + + /** + * Stops the query execution thread to terminate the query. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + // The query execution thread will clean itself up in the finally clause of runContinuous. + // We just need to interrupt the long running job. + queryExecutionThread.interrupt() + queryExecutionThread.join() + } + logInfo(s"Query $prettyIdString was stopped") + } } object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" + val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala new file mode 100644 index 0000000000000..65c5fc63c2f46 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.Closeable +import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ThreadUtils + +/** + * A wrapper for a continuous processing data reader, including a reading queue and epoch markers. + * + * This will be instantiated once per partition - successive calls to compute() in the + * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of + * offsets across epochs. Each compute() should call the next() method here until null is returned. + */ +class ContinuousQueuedDataReader( + partitionIndex: Int, + reader: ContinuousPartitionReader[InternalRow], + schema: StructType, + context: TaskContext, + dataQueueSize: Int, + epochPollIntervalMs: Long) extends Closeable { + // Important sequencing - we must get our starting point before the provider threads start running + private var currentOffset: PartitionOffset = reader.getOffset + + /** + * The record types in the read buffer. + */ + sealed trait ContinuousRecord + case object EpochMarker extends ContinuousRecord + case class ContinuousRow(row: InternalRow, offset: PartitionOffset) extends ContinuousRecord + + private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) + + private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + + private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--$coordinatorId--${context.partitionId()}") + private val epochMarkerGenerator = new EpochMarkerGenerator + epochMarkerExecutor.scheduleWithFixedDelay( + epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + private val dataReaderThread = new DataReaderThread(schema) + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener[Unit](_ => { + this.close() + }) + + private def shouldStop() = { + context.isInterrupted() || context.isCompleted() + } + + /** + * Return the next row to be read in the current epoch, or null if the epoch is done. + * + * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch + * will call next() again to start getting rows. 
+ */ + def next(): InternalRow = { + val POLL_TIMEOUT_MS = 1000 + var currentEntry: ContinuousRecord = null + + while (currentEntry == null) { + if (shouldStop()) { + // Force the epoch to end here. The writer will notice the context is interrupted + // or completed and not start a new one. This makes it possible to achieve clean + // shutdown of the streaming query. + // TODO: The obvious generalization of this logic to multiple stages won't work. It's + // invalid to send an epoch marker from the bottom of a task if all its child tasks + // haven't sent one. + currentEntry = EpochMarker + } else { + if (dataReaderThread.failureReason != null) { + throw new SparkException("Data read failed", dataReaderThread.failureReason) + } + if (epochMarkerGenerator.failureReason != null) { + throw new SparkException( + "Epoch marker generation failed", + epochMarkerGenerator.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + } + + currentEntry match { + case EpochMarker => + epochCoordEndpoint.send(ReportPartitionOffset( + partitionIndex, EpochTracker.getCurrentEpoch.get, currentOffset)) + null + case ContinuousRow(row, offset) => + currentOffset = offset + row + } + } + + override def close(): Unit = { + dataReaderThread.interrupt() + epochMarkerExecutor.shutdown() + } + + /** + * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when + * a new row arrives to the [[ContinuousPartitionReader]]. + */ + class DataReaderThread(schema: StructType) extends Thread( + s"continuous-reader--${context.partitionId()}--" + + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + private val toUnsafe = UnsafeProjection.create(schema) + + override def run(): Unit = { + TaskContext.setTaskContext(context) + try { + while (!shouldStop()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!shouldStop()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + // `InternalRow#copy` may not be properly implemented, for safety we convert to unsafe row + // before copy here. + queue.put(ContinuousRow(toUnsafe(reader.get()).copy(), reader.getOffset)) + } + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + logInfo(s"shutting down interrupted data reader thread $getName") + + case NonFatal(t) => + failureReason = t + logWarning("data reader thread failed", t) + // If we throw from this thread, we may kill the executor. Let the parent thread handle + // it. + + case t: Throwable => + failureReason = t + throw t + } finally { + reader.close() + } + } + } + + /** + * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with + * EpochMarker when a new epoch marker arrives. + */ + class EpochMarkerGenerator extends Runnable with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That + // field represents the epoch wrt the data being processed. 
The currentEpoch here is just a + // counter to ensure we send the appropriate number of markers if we fall behind the driver. + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch) + // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking + // a while. We catch up by injecting enough epoch markers immediately to catch up. This will + // result in some epochs being empty for this partition, but that's fine. + for (i <- currentEpoch to newEpoch - 1) { + queue.put(EpochMarker) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + throw t + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..a6cde2b8a710f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -17,25 +17,22 @@ package org.apache.spark.sql.execution.streaming.continuous -import scala.collection.JavaConverters._ - import org.json4s.DefaultFormats import org.json4s.jackson.Serialization -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceOptions) - extends ContinuousReader { +class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -57,18 +54,18 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA - - private var offset: Offset = _ + override def fullSchema(): StructType = RateStreamProvider.SCHEMA - override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) } - override def getStartOffset(): Offset = offset + override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime) - override def createDataReaderFactories(): 
java.util.List[DataReaderFactory[Row]] = { - val partitionStartMap = offset match { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start + + val partitionStartMap = startOffset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => throw new IllegalArgumentException( @@ -85,14 +82,18 @@ class RateStreamContinuousReader(options: DataSourceOptions) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousDataReaderFactory( + RateStreamContinuousInputPartition( start.value, start.runTimeMs, i, numPartitions, perPartitionRate) - .asInstanceOf[DataReaderFactory[Row]] - }.asJava + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + RateStreamContinuousReaderFactory } override def commit(end: Offset): Unit = {} @@ -113,43 +114,34 @@ class RateStreamContinuousReader(options: DataSourceOptions) } -case class RateStreamContinuousDataReaderFactory( +case class RateStreamContinuousInputPartition( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReaderFactory[Row] { - - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { - val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] - require(rateStreamOffset.partition == partitionIndex, - s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousDataReader( - rateStreamOffset.currentValue, - rateStreamOffset.currentTimeMs, - partitionIndex, - increment, - rowsPerSecond) - } + extends InputPartition - override def createDataReader(): DataReader[Row] = - new RateStreamContinuousDataReader( - startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) +object RateStreamContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamContinuousInputPartition] + new RateStreamContinuousPartitionReader( + p.startValue, p.startTimeMs, p.partitionIndex, p.increment, p.rowsPerSecond) + } } -class RateStreamContinuousDataReader( +class RateStreamContinuousPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReader[Row] { + extends ContinuousPartitionReader[InternalRow] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong private var currentValue = startValue - private var currentRow: Row = null + private var currentRow: InternalRow = null override def next(): Boolean = { currentValue += increment @@ -165,14 +157,14 @@ class RateStreamContinuousDataReader( return false } - currentRow = Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(nextReadTime)), + currentRow = InternalRow( + DateTimeUtils.fromMillis(nextReadTime), currentValue) true } - override def get: Row = currentRow + override def get: InternalRow = currentRow override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala 
new file mode 100644 index 0000000000000..28ab2448a6633 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.net.Socket +import java.sql.Timestamp +import java.util.Calendar +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable.ListBuffer + +import org.json4s.{DefaultFormats, NoTypeHints} +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.streaming.{Offset => _, _} +import org.apache.spark.sql.execution.streaming.sources.TextSocketReader +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.RpcUtils + + +/** + * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This ContinuousReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. + * + * The driver maintains a socket connection to the host-port, keeps the received messages in + * buckets and serves the messages to the executors via a RPC endpoint. + */ +class TextSocketContinuousReadSupport(options: DataSourceOptions) + extends ContinuousReadSupport with Logging { + + implicit val defaultFormats: DefaultFormats = DefaultFormats + + private val host: String = options.get("host").get() + private val port: Int = options.get("port").get().toInt + + assert(SparkSession.getActiveSession.isDefined) + private val spark = SparkSession.getActiveSession.get + private val numPartitions = spark.sparkContext.defaultParallelism + + @GuardedBy("this") + private var socket: Socket = _ + + @GuardedBy("this") + private var readThread: Thread = _ + + @GuardedBy("this") + private val buckets = Seq.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + + @GuardedBy("this") + private var currentOffset: Int = -1 + + // Exposed for tests. 
+ private[spark] var startOffset: TextSocketOffset = _ + + private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this) + @volatile private var endpointRef: RpcEndpointRef = _ + + initialize() + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + assert(offsets.length == numPartitions) + val offs = offsets + .map(_.asInstanceOf[ContinuousRecordPartitionOffset]) + .sortBy(_.partitionId) + .map(_.offset) + .toList + TextSocketOffset(offs) + } + + override def deserializeOffset(json: String): Offset = { + TextSocketOffset(Serialization.read[List[Int]](json)) + } + + override def initialOffset(): Offset = { + startOffset = TextSocketOffset(List.fill(numPartitions)(0)) + startOffset + } + + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } + + override def fullSchema(): StructType = { + if (includeTimestamp) { + TextSocketReader.SCHEMA_TIMESTAMP + } else { + TextSocketReader.SCHEMA_REGULAR + } + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[TextSocketOffset] + recordEndpoint.setStartOffsets(startOffset.offsets) + val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" + endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + val offsets = startOffset match { + case off: TextSocketOffset => off.offsets + case off => + throw new IllegalArgumentException( + s"invalid offset type ${off.getClass} for TextSocketContinuousReader") + } + + if (offsets.size != numPartitions) { + throw new IllegalArgumentException( + s"The previous run contained ${offsets.size} partitions, but" + + s" $numPartitions partitions are currently configured. The numPartitions option" + + " cannot be changed.") + } + + startOffset.offsets.zipWithIndex.map { + case (offset, i) => + TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp) + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + TextSocketReaderFactory + } + + override def commit(end: Offset): Unit = synchronized { + val endOffset = end match { + case off: TextSocketOffset => off + case _ => throw new IllegalArgumentException(s"TextSocketContinuousReader.commit()" + + s"received an offset ($end) that did not originate with an instance of this class") + } + + endOffset.offsets.zipWithIndex.foreach { + case (offset, partition) => + val max = startOffset.offsets(partition) + buckets(partition).size + if (offset > max) { + throw new IllegalStateException("Invalid offset " + offset + " to commit" + + " for partition " + partition + ". Max valid offset: " + max) + } + val n = offset - startOffset.offsets(partition) + buckets(partition).trimStart(n) + } + startOffset = endOffset + recordEndpoint.setStartOffsets(startOffset.offsets) + } + + /** Stop this source. */ + override def stop(): Unit = synchronized { + if (socket != null) { + try { + // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to + // stop the readThread is to close the socket. 
+ socket.close() + } catch { + case e: IOException => + } + socket = null + } + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + private def initialize(): Unit = synchronized { + socket = new Socket(host, port) + val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) + // Thread continuously reads from a socket and inserts data into buckets + readThread = new Thread(s"TextSocketContinuousReader($host, $port)") { + setDaemon(true) + + override def run(): Unit = { + try { + while (true) { + val line = reader.readLine() + if (line == null) { + // End of file reached + logWarning(s"Stream closed by $host:$port") + return + } + TextSocketContinuousReadSupport.this.synchronized { + currentOffset += 1 + val newData = (line, + Timestamp.valueOf( + TextSocketReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + buckets(currentOffset % numPartitions) += newData + } + } + } catch { + case e: IOException => + } + } + } + + readThread.start() + } + + override def toString: String = s"TextSocketContinuousReader[host: $host, port: $port]" + + private def includeTimestamp: Boolean = options.getBoolean("includeTimestamp", false) + +} + +/** + * Continuous text socket input partition. + */ +case class TextSocketContinuousInputPartition( + driverEndpointName: String, + partitionId: Int, + startOffset: Int, + includeTimestamp: Boolean) extends InputPartition + + +object TextSocketReaderFactory extends ContinuousPartitionReaderFactory { + + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[TextSocketContinuousInputPartition] + new TextSocketContinuousPartitionReader( + p.driverEndpointName, p.partitionId, p.startOffset, p.includeTimestamp) + } +} + + +/** + * Continuous text socket input partition reader. + * + * Polls the driver endpoint for new records. + */ +class TextSocketContinuousPartitionReader( + driverEndpointName: String, + partitionId: Int, + startOffset: Int, + includeTimestamp: Boolean) + extends ContinuousPartitionReader[InternalRow] { + + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[InternalRow] = None + + override def next(): Boolean = { + try { + current = getRecord + while (current.isEmpty) { + Thread.sleep(100) + current = getRecord + } + currentOffset += 1 + } catch { + case _: InterruptedException => + // Someone's trying to end the task; just let them. 
+ return false + } + true + } + + override def get(): InternalRow = { + current.get + } + + override def close(): Unit = {} + + override def getOffset: PartitionOffset = + ContinuousRecordPartitionOffset(partitionId, currentOffset) + + private def getRecord: Option[InternalRow] = + endpoint.askSync[Option[InternalRow]](GetRecord( + ContinuousRecordPartitionOffset(partitionId, currentOffset))).map(rec => + if (includeTimestamp) { + rec + } else { + InternalRow(rec.get(0, TextSocketReader.SCHEMA_TIMESTAMP) + .asInstanceOf[(String, Timestamp)]._1) + } + ) +} + +case class TextSocketOffset(offsets: List[Int]) extends Offset { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json: String = Serialization.write(offsets) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala new file mode 100644 index 0000000000000..a08411d746abe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.DataWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory +import org.apache.spark.util.Utils + +/** + * The RDD writing to a sink in continuous processing. + * + * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data + * to be written for one epoch, which we commit and forward to the driver. + * + * We keep repeating prev.compute() and writing new epochs until the query is shut down. + */ +class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory) + extends RDD[Unit](prev) { + + override val partitioner = prev.partitioner + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + while (!context.isInterrupted() && !context.isCompleted()) { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. 
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + val dataIterator = prev.compute(split, context) + dataWriter = writerFactory.createWriter( + context.partitionId(), + context.taskAttemptId(), + EpochTracker.getCurrentEpoch.get) + while (dataIterator.hasNext) { + dataWriter.write(dataIterator.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") + val msg = dataWriter.commit() + epochCoordinator.send( + CommitPartitionEpoch( + context.partitionId(), + EpochTracker.getCurrentEpoch.get, + msg) + ) + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.") + EpochTracker.incrementCurrentEpoch() + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so compute() will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } + + Iterator() + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index cc6808065c0cd..2238ce26e7b46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,15 +82,15 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( - writer, reader, query, startEpoch, session, env.rpcEnv) + writeSupport, readSupport, query, startEpoch, session, env.rpcEnv) val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref @@ -115,8 +115,8 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. 
*/ private[continuous] class EpochCoordinator( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, startEpoch: Long, session: SparkSession, @@ -137,30 +137,71 @@ private[continuous] class EpochCoordinator( private val partitionOffsets = mutable.Map[(Long, Int), PartitionOffset]() + private var lastCommittedEpoch = startEpoch - 1 + // Remembers epochs that have to wait for previous epochs to be committed first. + private val epochsWaitingToBeCommitted = mutable.HashSet.empty[Long] + private def resolveCommitsAtEpoch(epoch: Long) = { - val thisEpochCommits = - partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + val thisEpochCommits = findPartitionCommitsForEpoch(epoch) val nextEpochOffsets = partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochCommits.size == numWriterPartitions && nextEpochOffsets.size == numReaderPartitions) { - logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.") - // Sequencing is important here. We must commit to the writer before recording the commit - // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, thisEpochCommits.toArray) - query.commit(epoch) - - // Cleanup state from before this epoch, now that we know all partitions are forever past it. - for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) - } - for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + + // Check that last committed epoch is the previous one for sequencing of committed epochs. + // If not, add the epoch being currently processed to epochs waiting to be committed, + // otherwise commit it. + if (lastCommittedEpoch != epoch - 1) { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is waiting for epoch ${epoch - 1} to be committed first.") + epochsWaitingToBeCommitted.add(epoch) + } else { + commitEpoch(epoch, thisEpochCommits) + lastCommittedEpoch = epoch + + // Commit subsequent epochs that are waiting to be committed. + var nextEpoch = lastCommittedEpoch + 1 + while (epochsWaitingToBeCommitted.contains(nextEpoch)) { + val nextEpochCommits = findPartitionCommitsForEpoch(nextEpoch) + commitEpoch(nextEpoch, nextEpochCommits) + + epochsWaitingToBeCommitted.remove(nextEpoch) + lastCommittedEpoch = nextEpoch + nextEpoch += 1 + } + + // Cleanup state from before last committed epoch, + // now that we know all partitions are forever past it. + for (k <- partitionCommits.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionCommits.remove(k) + } + for (k <- partitionOffsets.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionOffsets.remove(k) + } } } } + /** + * Collect per-partition commits for an epoch. + */ + private def findPartitionCommitsForEpoch(epoch: Long): Iterable[WriterCommitMessage] = { + partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + } + + /** + * Commit epoch to the offset log. + */ + private def commitEpoch(epoch: Long, messages: Iterable[WriterCommitMessage]): Unit = { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is ready to be committed. Committing epoch $epoch.") + // Sequencing is important here. We must commit to the writer before recording the commit + // in the query, or we will end up dropping the commit if we restart in the middle. 
+ writeSupport.commit(epoch, messages.toArray) + query.commit(epoch) + } + override def receive: PartialFunction[Any, Unit] = { // If we just drop these messages, we won't do any writes to the query. The lame duck tasks // won't shed errors or anything. @@ -179,7 +220,7 @@ private[continuous] class EpochCoordinator( partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochOffsets.size == numReaderPartitions) { logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets") - query.addOffset(epoch, reader, thisEpochOffsets.toSeq) + query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala new file mode 100644 index 0000000000000..bc0ae428d4521 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +/** + * Tracks the current continuous processing epoch within a task. Call + * EpochTracker.getCurrentEpoch to get the current epoch. + */ +object EpochTracker { + // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will + // update the underlying AtomicLong as it finishes epochs. Other code should only read the value. + private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] { + override def initialValue() = new AtomicLong(-1) + } + + /** + * Get the current epoch for the current task, or None if the task has no current epoch. + */ + def getCurrentEpoch: Option[Long] = { + currentEpoch.get().get() match { + case n if n < 0 => None + case e => Some(e) + } + } + + /** + * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * between epochs. + */ + def incrementCurrentEpoch(): Unit = { + currentEpoch.get().incrementAndGet() + } + + /** + * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * at the beginning of a task. 
+ */ + def initializeCurrentEpoch(startEpoch: Long): Unit = { + currentEpoch.get().set(startEpoch) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala new file mode 100644 index 0000000000000..7ad21cc304e7c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport + +/** + * The logical plan for writing data in a continuous stream. + */ +case class WriteToContinuousDataSource( + writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala new file mode 100644 index 0000000000000..c216b61383856 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import scala.util.control.NonFatal + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport + +/** + * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. + */ +case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) + extends SparkPlan with Logging { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writerFactory = writeSupport.createStreamingWriterFactory() + val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) + + logInfo(s"Start processing data source write support: $writeSupport. " + + s"The input RDD has ${rdd.partitions.length} partitions.") + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + try { + // Force the RDD to run so continuous processing starts; no data is actually being collected + // to the driver, as ContinuousWriteRDD outputs nothing. + rdd.collect() + } catch { + case _: InterruptedException => + // Interruption is how continuous queries are ended, so accept and ignore the exception. + case cause: Throwable => + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } + } + + sparkContext.emptyRDD + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..9b13f6398d837 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcAddress +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition( + index: Int, + endpointName: String, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + + TaskContext.get().addTaskCompletionListener[Unit] { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the reduce side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * polls from its receiver until an epoch marker is sent. + * + * @param sc the RDD context + * @param numPartitions the number of read partitions for this RDD + * @param queueSize the size of the row buffers to use + * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD + * @param epochIntervalMs the checkpoint interval of the streaming query + */ +class ContinuousShuffleReadRDD( + sc: SparkContext, + numPartitions: Int, + queueSize: Int = 1024, + numShuffleWriters: Int = 1, + epochIntervalMs: Long = 1000, + val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}")) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition( + partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..42631c90ebc55 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..47b1f78b24505 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for writing to a continuous processing shuffle. + */ +trait ContinuousShuffleWriter { + def write(epoch: Iterator[UnsafeRow]): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala new file mode 100644 index 0000000000000..502ae0d4822e8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +/** + * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker. + * + * Each message comes tagged with writerId, identifying which writer the message is coming + * from. The receiver will only begin the next epoch once all writers have sent an epoch + * marker ending the current epoch. + */ +private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable { + def writerId: Int +} +private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) + extends RPCContinuousShuffleMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. + */ +private[continuous] class RPCContinuousShuffleReader( + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long, + override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. + private val queues = Array.fill(numShuffleWriters) { + new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize) + } + + // Exposed for testing to determine if the endpoint gets stopped on task end. + private[shuffle] val stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: RPCContinuousShuffleMessage => + // Note that this will block a thread in the shared RPC handler pool! + // The TCP-based shuffle handler (SPARK-24541) will avoid this problem. + queues(r.writerId).put(r) + context.reply(()) + } + + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + // An array of flags for whether each writer ID has gotten an epoch marker. + private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) + + private val executor = Executors.newFixedThreadPool(numShuffleWriters) + private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor) + + private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] { + override def call(): RPCContinuousShuffleMessage = queues(writerId).take() + } + + // Initialize by submitting tasks to read the first row from each writer. + (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) + + /** + * In each call to getNext(), we pull the next row available in the completion queue, and then + * submit another task to read the next row from the writer which returned it. + * + * When a writer sends an epoch marker, we note that it's finished and don't submit another + * task for it in this epoch. The iterator is over once all writers have sent an epoch marker. 
+ */ + override def getNext(): UnsafeRow = { + var nextRow: UnsafeRow = null + while (!finished && nextRow == null) { + completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { + case null => + // Try again if the poll didn't wait long enough to get a real result. + // But we should be getting at least an epoch marker every checkpoint interval. + val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect { + case (flag, idx) if !flag => idx + } + logWarning( + s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + + s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.") + + // The completion service guarantees this future will be available immediately. + case future => future.get() match { + case ReceiverRow(writerId, r) => + // Start reading the next element in the queue we just took from. + completion.submit(completionTask(writerId)) + nextRow = r + case ReceiverEpochMarker(writerId) => + // Don't read any more from this queue. If all the writers have sent epoch markers, + // the epoch is over; otherwise we need to loop again to poll from the remaining + // writers. + writerEpochMarkersReceived(writerId) = true + if (writerEpochMarkersReceived.forall(_ == true)) { + finished = true + } + } + } + } + + nextRow + } + + override def close(): Unit = { + executor.shutdownNow() + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..1c6f3ddb395e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import scala.concurrent.Future +import scala.concurrent.duration.Duration + +import org.apache.spark.Partitioner +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.ThreadUtils + +/** + * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. + * + * @param writerId The partition ID of this writer. + * @param outputPartitioner The partitioner on the reader side of the shuffle. + * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by + * partition ID within outputPartitioner. 
+ */ +class RPCContinuousShuffleWriter( + writerId: Int, + outputPartitioner: Partitioner, + endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { + + if (outputPartitioner.numPartitions != 1) { + throw new IllegalArgumentException("multiple readers not yet supported") + } + + if (outputPartitioner.numPartitions != endpoints.length) { + throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + + s"not match endpoint count ${endpoints.length}") + } + + def write(epoch: Iterator[UnsafeRow]): Unit = { + while (epoch.hasNext) { + val row = epoch.next() + endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) + } + + val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq + implicit val ec = ThreadUtils.sameThread + ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 628923d367ce7..adf52aba21a04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,30 +17,26 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils - object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) @@ -68,7 +64,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas addData(data.toTraversable) } - def readSchema(): StructType = encoder.schema + def fullSchema(): StructType = encoder.schema protected def logicalPlan: LogicalPlan @@ -81,8 +77,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. 
*/ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) - with MicroBatchReader with SupportsScanUnsafeRow with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -124,24 +119,22 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] - endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] - } - } - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) - override def getStartOffset: OffsetV2 = synchronized { - if (startOffset.offset == -1) null else startOffset + override def initialOffset: OffsetV2 = LongOffset(-1) + + override def latestOffset(): OffsetV2 = { + if (currentOffset.offset == -1) null else currentOffset } - override def getEndOffset: OffsetV2 = synchronized { - if (endOffset.offset == -1) null else endOffset + override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOffset = sc.start.asInstanceOf[LongOffset] + val endOffset = sc.end.get.asInstanceOf[LongOffset] synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -158,11 +151,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] - }.asJava + new MemoryStreamInputPartition(block) + }.toArray } } + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + MemoryStreamReaderFactory + } + private def generateDebugString( rows: Seq[UnsafeRow], startOrdinal: Int, @@ -203,10 +200,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) - extends DataReaderFactory[UnsafeRow] { - override def createDataReader(): DataReader[UnsafeRow] = { - new DataReader[UnsafeRow] { +class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition + +object MemoryStreamReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val records = partition.asInstanceOf[MemoryStreamInputPartition].records + new PartitionReader[InternalRow] { private var currentIndex = -1 override def next(): Boolean = { @@ -222,11 +221,20 @@ class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) } } +/** A common trait for MemorySinks with methods used for testing */ +trait MemorySinkBase extends BaseStreamingSink { + def allData: Seq[Row] + def latestBatchData: Seq[Row] + def dataSinceBatch(sinceBatchId: Long): Seq[Row] + def latestBatchId: Option[Long] +} + /** * A sink that stores the results in 
memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -236,7 +244,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.map(_.data).flatten + batches.flatMap(_.data) } def latestBatchId: Option[Long] = synchronized { @@ -245,6 +253,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { @@ -294,7 +306,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala similarity index 80% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala index d276403190b3c..833e62f35ede1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala @@ -17,18 +17,17 @@ package org.apache.spark.sql.execution.streaming.sources -import scala.collection.JavaConverters._ - import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the console sink */ -class ConsoleWriter(schema: StructType, options: DataSourceOptions) - extends StreamWriter with Logging { +class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) + extends StreamingWriteSupport with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) @@ -39,7 +38,7 @@ class ConsoleWriter(schema: 
StructType, options: DataSourceOptions) assert(SparkSession.getActiveSession.isDefined) protected val spark = SparkSession.getActiveSession.get - def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 @@ -62,8 +61,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark - .createDataFrame(rows.toList.asJava, schema) + Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) .show(numRowsToShow, isTruncated) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index c28919b8b729b..dbcc4483e5770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -17,26 +17,22 @@ package org.apache.spark.sql.execution.streaming.sources -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization -import org.apache.spark.SparkEnv -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} -import org.apache.spark.sql.{Encoder, Row, SQLContext} -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.{Encoder, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.streaming.{Offset => _, _} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.util.RpcUtils /** @@ -44,13 +40,14 @@ import org.apache.spark.util.RpcUtils * * ContinuousMemoryStream maintains a list of records for each partition. addData() will * distribute records evenly-ish across partitions. * * RecordEndpoint is set up as an endpoint for executor-side - * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified - * offset within the list, or null if that offset doesn't yet have a record. + * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at + * the specified offset within the list, or null if that offset doesn't yet have a record. 
*/ -class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) + extends MemoryStreamBase[A](sqlContext) + with ContinuousReadSupportProvider with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) - private val NUM_PARTITIONS = 2 protected val logicalPlan = StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) @@ -58,33 +55,23 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) // ContinuousReader implementation @GuardedBy("this") - private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) - - @GuardedBy("this") - private var startOffset: ContinuousMemoryStreamOffset = _ + private val records = Seq.fill(numPartitions)(new ListBuffer[A]) - private val recordEndpoint = new RecordEndpoint() + private val recordEndpoint = new ContinuousRecordEndpoint(records, this) @volatile private var endpointRef: RpcEndpointRef = _ def addData(data: TraversableOnce[A]): Offset = synchronized { // Distribute data evenly among partition lists. data.toSeq.zipWithIndex.map { - case (item, index) => records(index % NUM_PARTITIONS) += item + case (item, index) => records(index % numPartitions) += item } // The new target offset is the offset where all records in all partitions have been processed. - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } - override def setStartOffset(start: Optional[Offset]): Unit = synchronized { - // Inferred initial offset is position 0 in each partition. 
- startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) - }.asInstanceOf[ContinuousMemoryStreamOffset] - } - - override def getStartOffset: Offset = synchronized { - startOffset + override def initialOffset(): Offset = { + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) } override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { @@ -94,111 +81,117 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { ContinuousMemoryStreamOffset( offsets.map { - case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + case ContinuousRecordPartitionOffset(part, num) => (part, num) }.toMap ) } - override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[ContinuousMemoryStreamOffset] synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) startOffset.partitionNums.map { - case (part, index) => - new ContinuousMemoryStreamDataReaderFactory( - endpointName, part, index): DataReaderFactory[Row] - }.toList.asJava + case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index) + }.toArray } } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + ContinuousMemoryStreamReaderFactory + } + override def stop(): Unit = { if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) } override def commit(end: Offset): Unit = {} - // ContinuousReadSupport implementation + // ContinuousReadSupportProvider implementation // This is necessary because of how StreamTest finds the source for AddDataMemory steps. - def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - this - } - - /** - * Endpoint for executors to poll for records. - */ - private class RecordEndpoint extends ThreadSafeRpcEndpoint { - override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => - ContinuousMemoryStream.this.synchronized { - val buf = records(part) - val record = if (buf.size <= index) None else Some(buf(index)) - - context.reply(record.map(Row(_))) - } - } - } + options: DataSourceOptions): ContinuousReadSupport = this } object ContinuousMemoryStream { - case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset) protected val memoryStreamId = new AtomicInteger(0) def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + + def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1) } /** - * Data reader factory for continuous memory stream. 
+ * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamDataReaderFactory( +case class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends DataReaderFactory[Row] { - override def createDataReader: ContinuousMemoryStreamDataReader = - new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition + +object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[ContinuousMemoryStreamInputPartition] + new ContinuousMemoryStreamPartitionReader(p.driverEndpointName, p.partition, p.startOffset) + } } /** - * Data reader for continuous memory stream. + * An input partition reader for continuous memory stream. * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamDataReader( +class ContinuousMemoryStreamPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousDataReader[Row] { + startOffset: Int) extends ContinuousPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, SparkEnv.get.rpcEnv) private var currentOffset = startOffset - private var current: Option[Row] = None + private var current: Option[InternalRow] = None + + // Defense-in-depth against failing to propagate the task context. Since it's not inheritable, + // we have to do a bit of error prone work to get it into every thread used by continuous + // processing. We hope that some unit test will end up instantiating a continuous memory stream + // in such cases. + if (TaskContext.get() == null) { + throw new IllegalStateException("Task context was not set!") + } override def next(): Boolean = { - current = None + current = getRecord while (current.isEmpty) { Thread.sleep(10) - current = endpoint.askSync[Option[Row]]( - GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + current = getRecord } currentOffset += 1 true } - override def get(): Row = current.get + override def get(): InternalRow = current.get override def close(): Unit = {} - override def getOffset: ContinuousMemoryStreamPartitionOffset = - ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + override def getOffset: ContinuousRecordPartitionOffset = + ContinuousRecordPartitionOffset(partition, currentOffset) + + private def getRecord: Option[InternalRow] = + endpoint.askSync[Option[InternalRow]]( + GetRecord(ContinuousRecordPartitionOffset(partition, currentOffset))) } case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) @@ -206,6 +199,3 @@ case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) private implicit val formats = Serialization.formats(NoTypeHints) override def json(): String = Serialization.write(partitionNums) } - -case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int) - extends PartitionOffset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala new file mode 100644 index 0000000000000..03c567c58d46a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.api.python.PythonException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.streaming.DataStreamWriter + +class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T]) + extends Sink { + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + val resolvedEncoder = encoder.resolveAndBind( + data.logicalPlan.output, + data.sparkSession.sessionState.analyzer) + val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag) + val ds = data.sparkSession.createDataset(rdd)(encoder) + batchWriter(ds, batchId) + } + + override def toString(): String = "ForeachBatchSink" +} + + +/** + * Interface that is meant to be extended by Python classes via Py4J. + * Py4J allows Python classes to implement Java interfaces so that the JVM can call back + * Python objects. In this case, this allows the user-defined Python `foreachBatch` function + * to be called from JVM when the query is active. 
+ * */ +trait PythonForeachBatchFunction { + /** Call the Python implementation of this function */ + def call(batchDF: DataFrame, batchId: Long): Unit +} + +object PythonForeachBatchHelper { + def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = { + dsw.foreachBatch(pythonFunc.call _) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala similarity index 54% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala index df5d69d57e36f..4218fd51ad206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -31,23 +33,35 @@ import org.apache.spark.sql.types.StructType * [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. 
*/ -case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { - override def createStreamWriter( +case class ForeachWriteSupportProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) + extends StreamingWriteSupportProvider { + + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new StreamWriter with SupportsWriteInternalRow { + options: DataSourceOptions): StreamingWriteSupport = { + new StreamingWriteSupport { override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - val encoder = encoderFor[T].resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - ForeachWriterFactory(writer, encoder) + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) } override def toString: String = "ForeachSink" @@ -55,29 +69,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S } } -case class ForeachWriterFactory[T: Encoder]( +object ForeachWriteSupportProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriteSupportProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriteSupportProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]) - extends DataWriterFactory[InternalRow] { - override def createDataWriter( + rowConverter: InternalRow => T) + extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, encoder, partitionId, epochId) + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } } /** * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * * @param writer The [[ForeachWriter]] to process all data. - * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] * @param partitionId * @param epochId * @tparam T The type expected by the writer. 
 */ -class ForeachDataWriter[T : Encoder]( +class ForeachDataWriter[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T], + rowConverter: InternalRow => T, partitionId: Int, epochId: Long) extends DataWriter[InternalRow] { @@ -89,7 +118,7 @@ class ForeachDataWriter[T : Encoder]( if (!opened) return try { - writer.process(encoder.fromRow(record)) + writer.process(rowConverter(record)) } catch { case t: Throwable => writer.close(t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala new file mode 100644 index 0000000000000..9f88416871f8e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} + +/** + * A [[BatchWriteSupport]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped + * streaming write support. + */ +class MicroBatchWritSupport(epochId: Long, val writeSupport: StreamingWriteSupport) + extends BatchWriteSupport { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writeSupport.commit(epochId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + writeSupport.abort(epochId, messages) + } + + override def createBatchWriterFactory(): DataWriterFactory = { + new MicroBatchWriterFactory(epochId, writeSupport.createStreamingWriterFactory()) + } +} + +class MicroBatchWriterFactory(epochId: Long, streamingWriterFactory: StreamingDataWriterFactory) + extends DataWriterFactory { + + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + streamingWriterFactory.createWriter(partitionId, taskId, epochId) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala deleted file mode 100644 index 56f7ff25cbed0..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter - -/** - * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements - * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped - * streaming writer. - */ -class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - - override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory() -} - -class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) - extends DataSourceWriter with SupportsWriteInternalRow { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = - writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => throw new IllegalStateException( - "InternalRowMicroBatchWriter should only be created with base writer support") - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index e07355aa37dba..ac3c71cc222b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -20,21 +20,22 @@ package org.apache.spark.sql.execution.streaming.sources import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.Row -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[DataSourceWriter]] on the driver. + * to a [[BatchWriteSupport]] on the driver. 
* * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. */ -case object PackedRowWriterFactory extends DataWriterFactory[Row] { - override def createDataWriter( +case object PackedRowWriterFactory extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, - attemptNumber: Int, - epochId: Long): DataWriter[Row] = { + taskId: Long, + epochId: Long): DataWriter[InternalRow] = { new PackedRowDataWriter() } } @@ -43,15 +44,16 @@ case object PackedRowWriterFactory extends DataWriterFactory[Row] { * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most * recent interval. */ -case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage +case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage /** * A simple [[DataWriter]] that just sends all the rows it's received as a commit message. */ -class PackedRowDataWriter() extends DataWriter[Row] with Logging { - private val data = mutable.Buffer[Row]() +class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging { + private val data = mutable.Buffer[InternalRow]() - override def write(row: Row): Unit = data.append(row) + // Spark reuses the same `InternalRow` instance, here we copy it before buffer it. + override def write(row: InternalRow): Unit = data.append(row.copy()) override def commit(): PackedRowCommitMessage = { val msg = PackedRowCommitMessage(data.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala new file mode 100644 index 0000000000000..90680ea38fbd6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} + +// A special `MicroBatchReadSupport` that can get latestOffset with a start offset. 
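A hedged sketch of how an engine could consume the trait declared just below: when a read support is rate-controlled, the next target offset is computed relative to where the previous batch ended; otherwise the plain latestOffset() is used (the actual MicroBatchExecution wiring is not part of this hunk):

def nextTargetOffset(readSupport: MicroBatchReadSupport, lastEnd: Offset): Offset =
  readSupport match {
    // Rate-controlled sources bound how far past `lastEnd` the next batch may reach.
    case rc: RateControlMicroBatchReadSupport => rc.latestOffset(lastEnd)
    // Plain sources simply report the latest offset they have available.
    case other => other.latestOffset()
  }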
+trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { + + override def latestOffset(): Offset = { + throw new IllegalAccessException( + "latestOffset should not be called for RateControlMicroBatchReadSupport") + } + + def latestOffset(start: Offset): Offset +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala similarity index 75% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala index 6cf8520fc544f..f5364047adff1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala @@ -19,26 +19,24 @@ package org.apache.spark.sql.execution.streaming.sources import java.io._ import java.nio.charset.StandardCharsets -import java.util.Optional import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ - import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { +class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReadSupport with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -105,38 +103,30 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: @volatile private var lastTimeMs: Long = creationTimeMs - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA + override def initialOffset(): Offset = LongOffset(0L) - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end + override def latestOffset(): Offset = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) } override def 
deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + override def fullSchema(): StructType = SCHEMA + + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startSeconds = sc.start.asInstanceOf[LongOffset].offset + val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") if (endSeconds > maxSeconds) { throw new ArithmeticException("Integer overflow. Max offset with " + @@ -152,7 +142,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") if (rangeStart == rangeEnd) { - return List.empty.asJava + return Array.empty } val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) @@ -167,55 +157,58 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: } (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( + new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] - }.toList.asJava + }.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + RateStreamMicroBatchReaderFactory } override def commit(end: Offset): Unit = {} override def stop(): Unit = {} - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchDataReaderFactory( +case class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { + relativeMsPerValue: Double) extends InputPartition - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +object RateStreamMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamMicroBatchInputPartition] + new RateStreamMicroBatchPartitionReader(p.partitionId, p.numPartitions, p.rangeStart, + p.rangeEnd, p.localStartTimeMs, p.relativeMsPerValue) + } } -class RateStreamMicroBatchDataReader( +class RateStreamMicroBatchPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { - private var count = 0 + relativeMsPerValue: Double) extends PartitionReader[InternalRow] { + private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd } - override def get(): Row = { + override def get(): InternalRow = { val currValue = rangeStart + 
partitionId + numPartitions * count count += 1 val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - Row( - DateTimeUtils.toJavaTimestamp( - DateTimeUtils.fromMillis(relative + localStartTimeMs)), - currValue - ) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), currValue) } override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6bdd492f0cb35..6942dfbfe0ecf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.execution.streaming.sources -import java.util.Optional - import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types._ /** @@ -42,13 +39,12 @@ import org.apache.spark.sql.types._ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { if (options.get(ROWS_PER_SECOND).isPresent) { val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong if (rowsPerSecond <= 0) { @@ -74,17 +70,14 @@ class RateStreamProvider extends DataSourceV2 } } - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) + new RateStreamMicroBatchReadSupport(options, checkpointLocation) } - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + options: DataSourceOptions): ContinuousReadSupport = { + new RateStreamContinuousReadSupport(options) + } override def shortName(): String = "rate" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 5f58246083bb2..2509450f0da9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -23,15 +23,21 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import 
scala.util.control.NonFatal +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + import org.apache.spark.internal.Logging import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} +import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport, SupportsCustomWriterMetrics} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -39,13 +45,15 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { - override def createStreamWriter( +class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider + with MemorySinkBase with Logging { + + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + options: DataSourceOptions): StreamingWriteSupport = { + new MemoryStreamingWriteSupport(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -67,6 +75,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { @@ -96,7 +108,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case _ => throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + s"Output mode $outputMode is not supported by MemorySinkV2") } } else { logDebug(s"Skipping already committed batch: $batchId") @@ -107,32 +119,30 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.clear() } - override def toString(): String = "MemorySink" + def numRows: Int = synchronized { + batches.foldLeft(0)(_ + _.data.length) + } + + override def toString(): String = "MemorySinkV2" } -case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} +case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) + extends WriterCommitMessage {} -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) - extends DataSourceWriter with Logging { +class MemoryV2CustomMetrics(sink: MemorySinkV2) extends CustomMetrics { + private implicit val formats 
= Serialization.formats(NoTypeHints) + override def json(): String = Serialization.write(Map("numRows" -> sink.numRows)) +} - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) +class MemoryStreamingWriteSupport( + val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) + extends StreamingWriteSupport with SupportsCustomWriterMetrics { - def commit(messages: Array[WriterCommitMessage]): Unit = { - val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data - } - sink.write(batchId, outputMode, newRows) - } + private val customMemoryV2Metrics = new MemoryV2CustomMetrics(sink) - override def abort(messages: Array[WriterCommitMessage]): Unit = { - // Don't accept any of the new input. + override def createStreamingWriterFactory: MemoryWriterFactory = { + MemoryWriterFactory(outputMode, schema) } -} - -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) - extends StreamWriter { - - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -144,24 +154,36 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // Don't accept any of the new input. } + + override def getCustomMetrics: CustomMetrics = customMemoryV2Metrics } -case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { - override def createDataWriter( +case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) + extends DataWriterFactory with StreamingDataWriterFactory { + + override def createWriter( partitionId: Int, - attemptNumber: Int, - epochId: Long): DataWriter[Row] = { - new MemoryDataWriter(partitionId, outputMode) + taskId: Long): DataWriter[InternalRow] = { + new MemoryDataWriter(partitionId, outputMode, schema) + } + + override def createWriter( + partitionId: Int, + taskId: Long, + epochId: Long): DataWriter[InternalRow] = { + createWriter(partitionId, taskId) } } -class MemoryDataWriter(partition: Int, outputMode: OutputMode) - extends DataWriter[Row] with Logging { +class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType) + extends DataWriter[InternalRow] with Logging { private val data = mutable.Buffer[Row]() - override def write(row: Row): Unit = { - data.append(row) + private val encoder = RowEncoder(schema).resolveAndBind() + + override def write(row: InternalRow): Unit = { + data.append(encoder.fromRow(row)) } override def commit(): MemoryWriterCommitMessage = { @@ -175,10 +197,10 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) /** - * Used to query the data that has been written into a [[MemorySink]]. + * Used to query the data that has been written into a [[MemorySinkV2]]. 
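For reference, this is how the memory sink is normally exercised from user code (a usage sketch; `df` is assumed to be a streaming DataFrame and `spark` an active SparkSession; the V2 path is chosen internally, not by the user):

val query = df.writeStream
  .format("memory")            // test/demo-only sink; all data is kept on the driver
  .queryName("memory_test")    // the collected rows become queryable under this name
  .outputMode("append")
  .start()

spark.table("memory_test").show()   // reads back what the sink has received so far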
*/ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { - private val sizePerRow = output.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 5aae46b463398..b2a573eae504a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -19,25 +19,28 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket -import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.{Calendar, Locale} +import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String -object TextSocketMicroBatchReader { +object TextSocketReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) @@ -45,14 +48,12 @@ object TextSocketMicroBatchReader { } /** - * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and - * debugging. This MicroBatchReader will *not* work in production applications due to multiple - * reasons, including no support for fault recovery. + * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This MicroBatchReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. 
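For reference, the user-facing entry point to the socket source (a usage sketch; a process such as `nc -lk 9999` is assumed to be writing lines to the socket):

val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .option("includeTimestamp", true)   // adds the timestamp column alongside value
  .load()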
*/ -class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { - - private var startOffset: Offset = _ - private var endOffset: Offset = _ +class TextSocketMicroBatchReadSupport(options: DataSourceOptions) + extends MicroBatchReadSupport with Logging { private val host: String = options.get("host").get() private val port: Int = options.get("port").get().toInt @@ -68,7 +69,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - private val batches = new ListBuffer[(String, Timestamp)] + private val batches = new ListBuffer[(UTF8String, Long)] @GuardedBy("this") private var currentOffset: LongOffset = LongOffset(-1L) @@ -76,7 +77,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR @GuardedBy("this") private var lastOffsetCommitted: LongOffset = LongOffset(-1L) - initialize() + private val initialized: AtomicBoolean = new AtomicBoolean(false) /** This method is only used for unit test */ private[sources] def getCurrentOffset(): LongOffset = synchronized { @@ -98,10 +99,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReader.this.synchronized { - val newData = (line, - Timestamp.valueOf( - TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + TextSocketMicroBatchReadSupport.this.synchronized { + val newData = ( + UTF8String.fromString(line), + DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) ) currentOffset += 1 batches.append(newData) @@ -115,40 +116,37 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR readThread.start() } - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { - startOffset = start.orElse(LongOffset(-1L)) - endOffset = end.orElse(currentOffset) - } - - override def getStartOffset(): Offset = { - Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) - } + override def initialOffset(): Offset = LongOffset(-1L) - override def getEndOffset(): Offset = { - Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) - } + override def latestOffset(): Offset = currentOffset override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def readSchema(): StructType = { + override def fullSchema(): StructType = { if (options.getBoolean("includeTimestamp", false)) { - TextSocketMicroBatchReader.SCHEMA_TIMESTAMP + TextSocketReader.SCHEMA_TIMESTAMP } else { - TextSocketMicroBatchReader.SCHEMA_REGULAR + TextSocketReader.SCHEMA_REGULAR } } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - assert(startOffset != null && endOffset != null, - "start offset and end offset should already be set before create read tasks.") + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } - val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 - val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 + val endOrdinal = 
sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { + if (initialized.compareAndSet(false, true)) { + initialize() + } + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 batches.slice(sliceStart, sliceEnd) @@ -158,15 +156,19 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR val spark = SparkSession.getActiveSession.get val numPartitions = spark.sparkContext.defaultParallelism - val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + val slices = Array.fill(numPartitions)(new ListBuffer[(UTF8String, Long)]) rawList.zipWithIndex.foreach { case (r, idx) => slices(idx % numPartitions).append(r) } - (0 until numPartitions).map { i => - val slice = slices(i) - new DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new DataReader[Row] { + slices.map(TextSocketInputPartition) + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + new PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val slice = partition.asInstanceOf[TextSocketInputPartition].slice + new PartitionReader[InternalRow] { private var currentIdx = -1 override def next(): Boolean = { @@ -174,14 +176,14 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR currentIdx < slice.size } - override def get(): Row = { - Row(slice(currentIdx)._1, slice(currentIdx)._2) + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) } override def close(): Unit = {} } } - }.toList.asJava + } } override def commit(end: Offset): Unit = synchronized { @@ -214,11 +216,14 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def toString: String = s"TextSocket[host: $host, port: $port]" + override def toString: String = s"TextSocketV2[host: $host, port: $port]" } +case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition + class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupport with DataSourceRegister with Logging { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider + with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + @@ -238,16 +243,18 @@ class TextSocketSourceProvider extends DataSourceV2 } } - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { checkParameters(options) - if (schema.isPresent) { - throw new AnalysisException("The socket source does not support a user-specified schema.") - } + new TextSocketMicroBatchReadSupport(options) + } - new TextSocketMicroBatchReader(options) + override def createContinuousReadSupport( + checkpointLocation: String, + options: DataSourceOptions): ContinuousReadSupport = { + checkParameters(options) + new TextSocketContinuousReadSupport(options) } /** String that represents the format that this data source provider uses. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala new file mode 100644 index 0000000000000..0a16a3819b778 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types._ + + +object FlatMapGroupsWithStateExecHelper { + + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + /** + * Class to capture deserialized state and timestamp return by the state manager. + * This is intended for reuse. 
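For orientation, the state handled by this helper is the `S` of the user-facing (flat)mapGroupsWithState API; a minimal usage sketch (illustrative `Event` type and `events` Dataset, with spark.implicits._ in scope):

import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

case class Event(key: String)

def countEvents(key: String, values: Iterator[Event], state: GroupState[Long]): Long = {
  val updated = state.getOption.getOrElse(0L) + values.size
  state.update(updated)   // this value is what the StateManager below (de)serializes
  updated
}

val counts = events
  .groupByKey(_.key)
  .mapGroupsWithState(GroupStateTimeout.NoTimeout)(countEvents)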
+ */ + case class StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + + private[FlatMapGroupsWithStateExecHelper] def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } + } + + /** Interface for interacting with state data of FlatMapGroupsWithState */ + sealed trait StateManager extends Serializable { + def stateSchema: StructType + def getState(store: StateStore, keyRow: UnsafeRow): StateData + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit + def removeState(store: StateStore, keyRow: UnsafeRow): Unit + def getAllState(store: StateStore): Iterator[StateData] + } + + def createStateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean, + stateFormatVersion: Int): StateManager = { + stateFormatVersion match { + case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp) + case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } + + // =============================================================================================== + // =========================== Private implementations of StateManager =========================== + // =============================================================================================== + + /** Commmon methods for StateManager implementations */ + private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) + extends StateManager { + + protected def stateSerializerExprs: Seq[Expression] + protected def stateDeserializerExpr: Expression + protected def timeoutTimestampOrdinalInRow: Int + + /** Get deserialized state and corresponding timeout timestamp for a key */ + override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew(keyRow, stateRow, getStateObject(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + override def putState(store: StateStore, key: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(key, stateRow) + } + + override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + override def getAllState(store: StateStore): Iterator[StateData] = { + val stateData = StateData() + store.getRange(None, None).map { p => + stateData.withNew(p.key, p.value, getStateObject(p.value), getTimestamp(p.value)) + } + } + + private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) + private lazy val stateDeserializerFunc = { + ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) + } + private lazy val stateDataForGets = StateData() + + protected def getStateObject(row: UnsafeRow): Any = { + if (row != null) stateDeserializerFunc(row) else null + } + + protected def getStateRow(obj: Any): UnsafeRow = { + stateSerializerFunc(obj) + } + + /** Returns the timeout timestamp of a state row is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinalInRow) + } else NO_TIMESTAMP + } + + /** 
Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps) + } + } + + /** + * Version 1 of the StateManager which stores the user-defined state as flattened columns in + * the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The + * unsafe rows will look like this. + * + * UnsafeRow[ col1 | col2 | col3 | timestamp ] + * + * The limitation of this format is that timestamp cannot be set when the user-defined + * state has been removed. This is because the columns cannot be collectively marked to be + * empty/null. + */ + private class StateManagerImplV1( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + + override val stateSchema: StructType = stateAttributes.toStructType + + override val timeoutTimestampOrdinalInRow: Int = { + stateAttributes.indexOf(timestampTimeoutAttribute) + } + + override val stateSerializerExprs: Seq[Expression] = { + val encoderSerializer = stateEncoder.namedExpressions + if (shouldStoreTimestamp) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + stateEncoder.resolveAndBind().deserializer + } + + override protected def getStateRow(obj: Any): UnsafeRow = { + require(obj != null, "State object cannot be null") + super.getStateRow(obj) + } + } + + /** + * Version 2 of the StateManager which stores the user-defined state as a nested struct + * in the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The + * unsafe rows will look like this. + * ___________________________ + * | | + * | V + * UnsafeRow[ nested-struct | timestamp | UnsafeRow[ col1 | col2 | col3 ] ] + * + * This allows the entire user-defined state to be collectively marked as empty/null, + * thus allowing timestamp to be set without requiring the state to be present. 
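A plain-Scala illustration of the difference being described, not the UnsafeRow layout itself: nesting makes the whole user state a single nullable slot next to the timeout column, so a timeout can be recorded even when no state is present, which the flattened V1 layout cannot express.

// V2-style row shape (illustrative): nullable nested state plus a separate timeout column.
case class V2StateRow(groupState: Option[(Int, String)], timeoutTimestamp: Long)

// Representable in V2 but not in V1: state removed, timeout still tracked.
val timeoutOnly = V2StateRow(groupState = None, timeoutTimestamp = 1556668800000L)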
+ */ + private class StateManagerImplV2( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + /** Schema of the state rows saved in the state store */ + override val stateSchema: StructType = { + var schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema = schema.add("timeoutTimestamp", LongType, nullable = false) + schema + } + + // Ordinals of the information stored in the state row + private val nestedStateOrdinal = 0 + override val timeoutTimestampOrdinalInRow = 1 + + override val stateSerializerExprs: Seq[Expression] = { + val boundRefToSpecificInternalRow = BoundReference( + 0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true) + + val nestedStateSerExpr = + CreateNamedStruct(stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + + val nullSafeNestedStateSerExpr = { + val nullLiteral = Literal(null, nestedStateSerExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToSpecificInternalRow) -> nullLiteral), nestedStateSerExpr) + } + + if (shouldStoreTimestamp) { + Seq(nullSafeNestedStateSerExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nullSafeNestedStateSerExpr) + } + } + + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + val boundRefToNestedState = + BoundReference(nestedStateOrdinal, stateEncoder.schema, nullable = true) + val deserExpr = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + val nullLiteral = Literal(null, deserExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToNestedState) -> nullLiteral), elseValue = deserExpr) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index df722b953228b..92a2480e8b017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ -import java.nio.channels.ClosedChannelException +import java.util import java.util.Locale +import java.util.concurrent.atomic.LongAdder import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams @@ -166,7 +166,16 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def metrics: StateStoreMetrics = { - StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate), Map.empty) + // NOTE: we provide estimation of cache size as "memoryUsedBytes", and size of state for + // current version as "stateOnCurrentVersionSizeBytes" + val metricsFromProvider: Map[String, Long] = getMetricsForProvider() + + val customMetrics = metricsFromProvider.flatMap { case (name, value) => + // just allow searching from list cause the list is small enough + supportedCustomMetrics.find(_.name == name).map(_ -> value) + } + (metricStateOnCurrentVersionSizeBytes -> 
SizeEstimator.estimate(mapToUpdate)) + + StateStoreMetrics(mapToUpdate.size(), metricsFromProvider("memoryUsedBytes"), customMetrics) } /** @@ -181,6 +190,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } } + def getMetricsForProvider(): Map[String, Long] = synchronized { + Map("memoryUsedBytes" -> SizeEstimator.estimate(loadedMaps), + metricLoadedMapCacheHit.name -> loadedMapCacheHitCount.sum(), + metricLoadedMapCacheMiss.name -> loadedMapCacheMissCount.sum()) + } + /** Get the state store for making updates to create a new `version` of the store. */ override def getStore(version: Long): StateStore = synchronized { require(version >= 0, "Version cannot be less than 0") @@ -205,6 +220,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf + this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory fm.mkdirs(baseDir) } @@ -222,11 +238,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def close(): Unit = { - loadedMaps.values.foreach(_.clear()) + loadedMaps.values.asScala.foreach(_.clear()) } override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = { - Nil + metricStateOnCurrentVersionSizeBytes :: metricLoadedMapCacheHit :: metricLoadedMapCacheMiss :: + Nil } override def toString(): String = { @@ -241,18 +258,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ + @volatile private var numberOfVersionsToRetainInMemory: Int = _ - private lazy val loadedMaps = new mutable.HashMap[Long, MapType] + private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse) private lazy val baseDir = stateStoreId.storeCheckpointLocation() private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + private val loadedMapCacheHitCount: LongAdder = new LongAdder + private val loadedMapCacheMissCount: LongAdder = new LongAdder + + private lazy val metricStateOnCurrentVersionSizeBytes: StateStoreCustomSizeMetric = + StateStoreCustomSizeMetric("stateOnCurrentVersionSizeBytes", + "estimated size of state only on current version") + + private lazy val metricLoadedMapCacheHit: StateStoreCustomMetric = + StateStoreCustomSumMetric("loadedMapCacheHitCount", + "count of cache hit on states cache in provider") + + private lazy val metricLoadedMapCacheMiss: StateStoreCustomMetric = + StateStoreCustomSumMetric("loadedMapCacheMissCount", + "count of cache miss on states cache in provider") + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { finalizeDeltaFile(output) - loadedMaps.put(newVersion, map) + putStateIntoStateCacheMap(newVersion, map) } } @@ -262,7 +295,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit */ private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized { val versionsInFiles = fetchFiles().map(_.version).toSet - val versionsLoaded = loadedMaps.keySet + val versionsLoaded = loadedMaps.keySet.asScala val allKnownVersions = versionsInFiles ++ versionsLoaded val unsafeRowTuple = new 
UnsafeRowPair() if (allKnownVersions.nonEmpty) { @@ -272,46 +305,92 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } else Iterator.empty } + /** This method is intended to be only used for unit test(s). DO NOT TOUCH ELEMENTS IN MAP! */ + private[state] def getLoadedMaps(): util.SortedMap[Long, MapType] = synchronized { + // shallow copy as a minimal guard + loadedMaps.clone().asInstanceOf[util.SortedMap[Long, MapType]] + } + + private def putStateIntoStateCacheMap(newVersion: Long, map: MapType): Unit = synchronized { + if (numberOfVersionsToRetainInMemory <= 0) { + if (loadedMaps.size() > 0) loadedMaps.clear() + return + } + + while (loadedMaps.size() > numberOfVersionsToRetainInMemory) { + loadedMaps.remove(loadedMaps.lastKey()) + } + + val size = loadedMaps.size() + if (size == numberOfVersionsToRetainInMemory) { + val versionIdForLastKey = loadedMaps.lastKey() + if (versionIdForLastKey > newVersion) { + // this is the only case which we can avoid putting, because new version will be placed to + // the last key and it should be evicted right away + return + } else if (versionIdForLastKey < newVersion) { + // this case needs removal of the last key before putting new one + loadedMaps.remove(versionIdForLastKey) + } + } + + loadedMaps.put(newVersion, map) + } + /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): MapType = { // Shortcut if the map for this version is already there to avoid a redundant put. - val loadedCurrentVersionMap = synchronized { loadedMaps.get(version) } + val loadedCurrentVersionMap = synchronized { Option(loadedMaps.get(version)) } if (loadedCurrentVersionMap.isDefined) { + loadedMapCacheHitCount.increment() return loadedCurrentVersionMap.get } - val snapshotCurrentVersionMap = readSnapshotFile(version) - if (snapshotCurrentVersionMap.isDefined) { - synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } - return snapshotCurrentVersionMap.get - } - // Find the most recent map before this version that we can. - // [SPARK-22305] This must be done iteratively to avoid stack overflow. - var lastAvailableVersion = version - var lastAvailableMap: Option[MapType] = None - while (lastAvailableMap.isEmpty) { - lastAvailableVersion -= 1 + logWarning(s"The state for version $version doesn't exist in loadedMaps. " + + "Reading snapshot file and delta files if needed..." + + "Note that this is normal for the first batch of starting query.") - if (lastAvailableVersion <= 0) { - // Use an empty map for versions 0 or less. - lastAvailableMap = Some(new MapType) - } else { - lastAvailableMap = - synchronized { loadedMaps.get(lastAvailableVersion) } - .orElse(readSnapshotFile(lastAvailableVersion)) + loadedMapCacheMissCount.increment() + + val (result, elapsedMs) = Utils.timeTakenMs { + val snapshotCurrentVersionMap = readSnapshotFile(version) + if (snapshotCurrentVersionMap.isDefined) { + synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) } + return snapshotCurrentVersionMap.get } - } - // Load all the deltas from the version after the last available one up to the target version. - // The last available version is the one with a full snapshot, so it doesn't need deltas. - val resultMap = new MapType(lastAvailableMap.get) - for (deltaVersion <- lastAvailableVersion + 1 to version) { - updateFromDeltaFile(deltaVersion, resultMap) + // Find the most recent map before this version that we can. 
+ // [SPARK-22305] This must be done iteratively to avoid stack overflow. + var lastAvailableVersion = version + var lastAvailableMap: Option[MapType] = None + while (lastAvailableMap.isEmpty) { + lastAvailableVersion -= 1 + + if (lastAvailableVersion <= 0) { + // Use an empty map for versions 0 or less. + lastAvailableMap = Some(new MapType) + } else { + lastAvailableMap = + synchronized { Option(loadedMaps.get(lastAvailableVersion)) } + .orElse(readSnapshotFile(lastAvailableVersion)) + } + } + + // Load all the deltas from the version after the last available one up to the target version. + // The last available version is the one with a full snapshot, so it doesn't need deltas. + val resultMap = new MapType(lastAvailableMap.get) + for (deltaVersion <- lastAvailableVersion + 1 to version) { + updateFromDeltaFile(deltaVersion, resultMap) + } + + synchronized { putStateIntoStateCacheMap(version, resultMap) } + resultMap } - synchronized { loadedMaps.put(version, resultMap) } - resultMap + logDebug(s"Loading state for $version takes $elapsedMs ms.") + + result } private def writeUpdateToDeltaFile( @@ -490,15 +569,18 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Perform a snapshot of the store to allow delta files to be consolidated */ private def doSnapshot(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val lastVersion = files.last.version val deltaFilesForLastVersion = filesForVersion(files, lastVersion).filter(_.isSnapshot == false) - synchronized { loadedMaps.get(lastVersion) } match { + synchronized { Option(loadedMaps.get(lastVersion)) } match { case Some(map) => if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { - writeSnapshotFile(lastVersion, map) + val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map)) + logDebug(s"writeSnapshotFile() took $e2 ms.") } case None => // The last map is not loaded, probably some other instance is in charge @@ -517,19 +599,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit */ private[state] def cleanup(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain if (earliestVersionToRetain > 0) { val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head - synchronized { - val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq - mapsToRemove.foreach(loadedMaps.remove) - } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) - filesToDelete.foreach { f => - fm.delete(f.path) + val (_, e2) = Utils.timeTakenMs { + filesToDelete.foreach { f => + fm.delete(f.path) + } } + logDebug(s"deleting files took $e2 ms.") logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7eb68c21569ba..d3313b8a315c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -138,6 +138,8 @@ trait StateStoreCustomMetric { def 
name: String def desc: String } + +case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 765ff076cb467..d145082a39b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -34,6 +34,9 @@ class StateStoreConf(@transient private val sqlConf: SQLConf) /** Minimum versions a State Store implementation should retain to allow rollbacks */ val minVersionsToRetain: Int = sqlConf.minBatchesToRetain + /** Maximum count of versions a State Store implementation should retain in memory */ + val maxVersionsToRetainInMemory: Int = sqlConf.maxBatchesToRetainInMemory + /** * Optional fully qualified name of the subclass of [[StateStoreProvider]] * managing state data. That is, the implementation of the State Store to use. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 01d8e75980993..3f11b8f79943c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( StateStoreId(checkpointLocation, operatorId, partition.index), queryRunId) + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. + val currentVersion = EpochTracker.getCurrentEpoch match { + case None => storeVersion + case Some(value) => value + } + store = StateStore.get( - storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala new file mode 100644 index 0000000000000..9bfb9561b42a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.types.StructType + +/** + * Base trait for state manager purposed to be used from streaming aggregations. + */ +sealed trait StreamingAggregationStateManager extends Serializable { + + /** Extract columns consisting key from input row, and return the new row for key columns. */ + def getKey(row: UnsafeRow): UnsafeRow + + /** Calculate schema for the value of state. The schema is mainly passed to the StateStoreRDD. */ + def getStateValueSchema: StructType + + /** Get the current value of a non-null key from the target state store. */ + def get(store: StateStore, key: UnsafeRow): UnsafeRow + + /** + * Put a new value for a non-null key to the target state store. Note that key will be + * extracted from the input row, and the key would be same as the result of getKey(inputRow). + */ + def put(store: StateStore, row: UnsafeRow): Unit + + /** + * Commit all the updates that have been made to the target state store, and return the + * new version. + */ + def commit(store: StateStore): Long + + /** Remove a single non-null key from the target state store. */ + def remove(store: StateStore, key: UnsafeRow): Unit + + /** Return an iterator containing all the key-value pairs in target state store. */ + def iterator(store: StateStore): Iterator[UnsafeRowPair] + + /** Return an iterator containing all the keys in target state store. */ + def keys(store: StateStore): Iterator[UnsafeRow] + + /** Return an iterator containing all the values in target state store. 
*/ + def values(store: StateStore): Iterator[UnsafeRow] +} + +object StreamingAggregationStateManager extends Logging { + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + def createStateManager( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute], + stateFormatVersion: Int): StreamingAggregationStateManager = { + stateFormatVersion match { + case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) + case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } +} + +abstract class StreamingAggregationStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) + + override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) + + override def commit(store: StateStore): Long = store.commit() + + override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) + + override def keys(store: StateStore): Iterator[UnsafeRow] = { + // discard and don't convert values to avoid computation + store.getRange(None, None).map(_.key) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 1. + * In state version 1, the schema of key and value in state are follow: + * + * - key: Same as key expressions. + * - value: Same as input row attributes. The schema of value contains key expressions as well. + * + * @param keyExpressions The attributes of keys. + * @param inputRowAttributes The attributes of input row. + */ +class StreamingAggregationStateManagerImplV1( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + override def getStateValueSchema: StructType = inputRowAttributes.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + store.get(key) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + store.put(getKey(row), row) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator() + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(_.value) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 2. + * In state version 2, the schema of key and value in state are follow: + * + * - key: Same as key expressions. + * - value: The diff between input row attributes and key expressions. + * + * The schema of value is changed to optimize the memory/space usage in state, via removing + * duplicated columns in key-value pair. Hence key columns are excluded from the schema of value. + * + * @param keyExpressions The attributes of keys. + * @param inputRowAttributes The attributes of input row. 
+ */ +class StreamingAggregationStateManagerImplV2( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) + private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + + // flag to check whether the row needs to be project into input row attributes after join + // e.g. if the fields in the joined row are not in the expected order + private val needToProjectToRestoreValue: Boolean = + keyValueJoinedExpressions != inputRowAttributes + + @transient private lazy val valueProjector = + GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) + + @transient private lazy val joiner = + GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( + inputRowAttributes, keyValueJoinedExpressions) + + override def getStateValueSchema: StructType = valueExpressions.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + val savedState = store.get(key) + if (savedState == null) { + return savedState + } + + restoreOriginalRow(key, savedState) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + val key = keyProjector(row) + val value = valueProjector(row) + store.put(key, value) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair))) + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(rowPair => restoreOriginalRow(rowPair)) + } + + private def restoreOriginalRow(rowPair: UnsafeRowPair): UnsafeRow = { + restoreOriginalRow(rowPair.key, rowPair.value) + } + + private def restoreOriginalRow(key: UnsafeRow, value: UnsafeRow): UnsafeRow = { + val joinedRow = joiner.join(key, value) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 6b386308c79fb..352b3d3616fba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -269,10 +269,15 @@ class SymmetricHashJoinStateManager( keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes, keyWithIndexToValueMetrics.customMetrics.map { + case (s @ StateStoreCustomSumMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value case (s @ StateStoreCustomSizeMetric(_, desc), value) => s.copy(desc = newDesc(desc)) -> value case (s @ StateStoreCustomTimingMetric(_, desc), value) => s.copy(desc = newDesc(desc)) -> value + case (s, _) => + throw new IllegalArgumentException( + s"Unknown state store custom metric is found at metrics: $s") } ) } @@ -290,7 +295,7 @@ class SymmetricHashJoinStateManager( private val keyWithIndexToValue = new KeyWithIndexToValueStore() // Clean up any state store resources if necessary at 
the end of the task - Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } } + Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } /** Helper trait for invoking common functionalities of a state store. */ private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 0b32327e51dbf..b6021438e902b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -61,7 +61,7 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) val wrappedF = (store: StateStore, iter: Iterator[T]) => { // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener(_ => { + TaskContext.get().addTaskCompletionListener[Unit](_ => { if (!store.hasCommitted) store.abort() }) cleanedF(store, iter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b9b07a2e688f9..c11af345b0248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ -import org.apache.spark.util.{CompletionIterator, NextIterator} +import org.apache.spark.util.{CompletionIterator, NextIterator, Utils} /** Used to identify the state store for a given operator. */ @@ -90,19 +90,22 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => * the driver after this SparkPlan has been executed and metrics have been updated. */ def getProgress(): StateOperatorProgress = { + val customMetrics = stateStoreCustomMetrics + .map(entry => entry._1 -> longMetric(entry._1).value) + + val javaConvertedCustomMetrics: java.util.HashMap[String, java.lang.Long] = + new java.util.HashMap(customMetrics.mapValues(long2Long).asJava) + new StateOperatorProgress( numRowsTotal = longMetric("numTotalStateRows").value, numRowsUpdated = longMetric("numUpdatedStateRows").value, - memoryUsedBytes = longMetric("stateMemory").value) + memoryUsedBytes = longMetric("stateMemory").value, + javaConvertedCustomMetrics + ) } /** Records the duration of running `body` for the next query progress update. */ - protected def timeTakenMs(body: => Unit): Long = { - val startTime = System.nanoTime() - val result = body - val endTime = System.nanoTime() - math.max(NANOSECONDS.toMillis(endTime - startTime), 0) - } + protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** * Set the SQL metrics related to the state store. 
@@ -120,12 +123,20 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => private def stateStoreCustomMetrics: Map[String, SQLMetric] = { val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass) provider.supportedCustomMetrics.map { + case StateStoreCustomSumMetric(name, desc) => + name -> SQLMetrics.createMetric(sparkContext, desc) case StateStoreCustomSizeMetric(name, desc) => name -> SQLMetrics.createSizeMetric(sparkContext, desc) case StateStoreCustomTimingMetric(name, desc) => name -> SQLMetrics.createTimingMetric(sparkContext, desc) }.toMap } + + /** + * Should the MicroBatchExecution run another batch based on this stateful operator and the + * current updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false } /** An operator that supports watermark. */ @@ -166,6 +177,18 @@ trait WatermarkSupport extends UnaryExecNode { } } } + + protected def removeKeysOlderThanWatermark( + storeManager: StreamingAggregationStateManager, + store: StateStore): Unit = { + if (watermarkPredicateForKeys.nonEmpty) { + storeManager.keys(store).foreach { keyRow => + if (watermarkPredicateForKeys.get.eval(keyRow)) { + storeManager.remove(store, keyRow) + } + } + } + } } object WatermarkSupport { @@ -200,20 +223,23 @@ object WatermarkSupport { case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreReader { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val hasInput = iter.hasNext if (!hasInput && keyExpressions.isEmpty) { // If our `keyExpressions` are empty, we're getting a global aggregation. 
In that case @@ -223,10 +249,10 @@ case class StateStoreRestoreExec( store.iterator().map(_.value) } else { iter.flatMap { row => - val key = getKey(row) - val savedState = store.get(key) + val key = stateManager.getKey(row.asInstanceOf[UnsafeRow]) + val restoredRow = stateManager.get(store, key) numOutputRows += 1 - Option(savedState).toSeq :+ row + Option(restoredRow).toSeq :+ row } } } @@ -253,9 +279,13 @@ case class StateStoreSaveExec( stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, @@ -264,11 +294,10 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") @@ -281,19 +310,18 @@ case class StateStoreSaveExec( allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numUpdatedStateRows += 1 } } allRemovalsTimeMs += 0 commitTimeMs += timeTakenMs { - store.commit() + stateManager.commit(store) } setStoreMetrics(store) - store.iterator().map { rowPair => + stateManager.values(store).map { valueRow => numOutputRows += 1 - rowPair.value + valueRow } // Update and output only rows being evicted from the StateStore @@ -303,14 +331,13 @@ case class StateStoreSaveExec( val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row)) while (filteredIter.hasNext) { val row = filteredIter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numUpdatedStateRows += 1 } } val removalStartTimeNs = System.nanoTime - val rangeIter = store.getRange(None, None) + val rangeIter = stateManager.iterator(store) new NextIterator[InternalRow] { override protected def getNext(): InternalRow = { @@ -318,7 +345,7 @@ case class StateStoreSaveExec( while(rangeIter.hasNext && removedValueRow == null) { val rowPair = rangeIter.next() if (watermarkPredicateForKeys.get.eval(rowPair.key)) { - store.remove(rowPair.key) + stateManager.remove(store, rowPair.key) removedValueRow = rowPair.value } } @@ -332,7 +359,7 @@ case class StateStoreSaveExec( override protected def close(): Unit = { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) - commitTimeMs += timeTakenMs { store.commit() } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } @@ -340,37 +367,36 @@ case class StateStoreSaveExec( // Update and output modified rows from the StateStore. 
case Some(Update) => - val updatesStartTimeNs = System.nanoTime - - new Iterator[InternalRow] { - + new NextIterator[InternalRow] { // Filter late date using watermark if specified private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } + private val updatesStartTimeNs = System.nanoTime - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - - // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) - false + override protected def getNext(): InternalRow = { + if (baseIterator.hasNext) { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + stateManager.put(store, row) + numOutputRows += 1 + numUpdatedStateRows += 1 + row } else { - true + finished = true + null } } - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) - numOutputRows += 1 - numUpdatedStateRows += 1 - row + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { + removeKeysOlderThanWatermark(stateManager, store) + } + commitTimeMs += timeTakenMs { stateManager.commit(store) } + setStoreMetrics(store) } } @@ -390,6 +416,12 @@ case class StateStoreSaveExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } /** Physical operator for executing streaming Deduplicate. 
*/ @@ -456,6 +488,10 @@ case class StreamingDeduplicateExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } object StreamingDeduplicateExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 582528777f90e..a7a24ac3641b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -58,21 +58,21 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L _content ++= new RunningExecutionTable( parent, s"Running Queries (${running.size})", currentTime, - running.sortBy(_.submissionTime).reverse).toNodeSeq + running.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (completed.nonEmpty) { _content ++= new CompletedExecutionTable( parent, s"Completed Queries (${completed.size})", currentTime, - completed.sortBy(_.submissionTime).reverse).toNodeSeq + completed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (failed.nonEmpty) { _content ++= new FailedExecutionTable( parent, s"Failed Queries (${failed.size})", currentTime, - failed.sortBy(_.submissionTime).reverse).toNodeSeq + failed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } _content } @@ -111,7 +111,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } - UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000)) } } @@ -133,7 +133,10 @@ private[ui] abstract class ExecutionTable( protected def header: Seq[String] - protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { + protected def row( + request: HttpServletRequest, + currentTime: Long, + executionUIData: SQLExecutionUIData): Seq[Node] = { val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - submissionTime @@ -141,7 +144,7 @@ private[ui] abstract class ExecutionTable( def jobLinks(status: JobExecutionStatus): Seq[Node] = { executionUIData.jobs.flatMap { case (jobId, jobStatus) => if (jobStatus == status) { - [{jobId.toString}] + [{jobId.toString}] } else { None } @@ -153,7 +156,7 @@ private[ui] abstract class ExecutionTable( {executionUIData.executionId.toString} - {descriptionCell(executionUIData)} + {descriptionCell(request, executionUIData)} {UIUtils.formatDate(submissionTime)} @@ -179,7 +182,9 @@ private[ui] abstract class ExecutionTable( } - private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { + private def descriptionCell( + request: HttpServletRequest, + execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details != null && execution.details.nonEmpty) { +details @@ -192,27 +197,28 @@ private[ui] abstract class ExecutionTable( } val desc = if (execution.description != null && execution.description.nonEmpty) { - {execution.description} + {execution.description} } else { - {execution.executionId} + {execution.executionId} }
      {desc} {details}
      } - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = {

      {tableName}

      {UIUtils.listingTable[SQLExecutionUIData]( - header, row(currentTime, _), executionUIDatas, id = Some(tableId))} + header, row(request, currentTime, _), executionUIDatas, id = Some(tableId))}
      } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) - private def executionURL(executionID: Long): String = - s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" + private def executionURL(request: HttpServletRequest, executionID: Long): String = + s"${UIUtils.prependBaseUri( + request, parent.basePath)}/${parent.prefix}/execution/?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e0554f0c4d337..877176b030f8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -49,7 +49,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
    • {label} {jobs.toSeq.sorted.map { jobId => - {jobId.toString}  + {jobId.toString}  }}
    • } else { @@ -77,27 +77,31 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging val graph = sqlStore.planGraph(executionId) summary ++ - planVisualization(metrics, graph) ++ + planVisualization(request, metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse {
      No information to display for query {executionId}
      } - UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) + UIUtils.headerSparkPage( + request, s"Details for Query $executionId", content, parent, Some(5000)) } - private def planVisualizationResources: Seq[Node] = { + private def planVisualizationResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - - - + + + + + // scalastyle:on } - private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { + private def planVisualization( + request: HttpServletRequest, + metrics: Map[Long, String], + graph: SparkPlanGraph): Seq[Node] = { val metadata = graph.allNodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
      {node.desc}
      @@ -112,13 +116,13 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
      {graph.allNodes.size.toString}
      {metadata} - {planVisualizationResources} + {planVisualizationResources(request)} } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
      diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2b6bb48467eb3..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -289,7 +289,7 @@ class SQLAppStatusListener( private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event Option(liveExecutions.get(executionId)).foreach { exec => - exec.driverAccumUpdates = accumUpdates.toMap + exec.driverAccumUpdates = exec.driverAccumUpdates ++ accumUpdates update(exec) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 800a2ea3f3996..fede0f3e92d67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -112,9 +112,11 @@ case class WindowExec( * * @param frame to evaluate. This can either be a Row or Range frame. * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + private[this] def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { (frame, bound) match { case (RowFrame, CurrentRow) => RowBoundOrdering(0) @@ -144,7 +146,7 @@ case class WindowExec( val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone)) + TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a== b => Add(expr, boundOffset) } val bound = newMutableProjection(boundExpr :: Nil, child.output) @@ -197,6 +199,7 @@ case class WindowExec( // Map the groups to a (unbound) expression and frame factory pair. var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone framedFunctions.toSeq.map { case (key, (expressions, functionSeq)) => val ordinal = numExpressions @@ -237,7 +240,7 @@ case class WindowExec( new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, upper, timeZone)) } // Shrinking Frame. @@ -246,7 +249,7 @@ case class WindowExec( new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower)) + createBoundOrdering(frameType, lower, timeZone)) } // Moving Frame. @@ -255,8 +258,8 @@ case class WindowExec( new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower), - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) } } @@ -320,8 +323,6 @@ case class WindowExec( fetchNextRow() // Manage the current partition. 
- val inputFields = child.output.length - val buffer: ExternalAppendOnlyUnsafeRowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index bdc4bb4422ae7..7bd20dbe8f6d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.types.DataType @@ -40,7 +41,7 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Option[Seq[DataType]]) { + inputTypes: Option[Seq[ScalaReflection.Schema]]) { private var _nameOption: Option[String] = None private var _nullable: Boolean = true @@ -72,10 +73,11 @@ case class UserDefinedFunction protected[sql] ( f, dataType, exprs.map(_.expr), - inputTypes.getOrElse(Nil), + inputTypes.map(_.map(_.dataType)).getOrElse(Nil), udfName = _nameOption, nullable = _nullable, - udfDeterministic = _deterministic)) + udfDeterministic = _deterministic, + nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil))) } private def copyAll(): UserDefinedFunction = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bea8c0e445002..a261a7c1752d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -39,7 +39,21 @@ import org.apache.spark.util.Utils /** - * Functions available for DataFrame operations. + * Commonly used functions available for DataFrame operations. Using functions defined here provides + * a little bit more compile-time safety to make sure the function exists. + * + * Spark also includes more built-in functions that are less common and are not defined here. + * You can still access them (and all the functions defined here) using the `functions.expr()` API + * and calling them through a SQL expression string. You can find the entire list of functions + * at SQL API documentation. + * + * As an example, `isnan` is a function that is defined here. You can use `isnan(col("myCol"))` + * to invoke the `isnan` function. This way the programming language's compiler ensures `isnan` + * exists and is of the proper form. You can also use `expr("isnan(myCol)")` function to invoke the + * same function. In this case, Spark itself will ensure `isnan` exists when it analyzes the query. + * + * `regr_count` is an example of a function that is built-in but not defined here, because it is + * less commonly used. To invoke it, use `expr("regr_count(yCol, xCol)")`. * * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions @@ -283,6 +297,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. 
+ * * @group agg_funcs * @since 1.6.0 */ @@ -291,6 +308,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -299,6 +319,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -307,6 +330,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -422,6 +448,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -435,6 +464,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -448,6 +480,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -459,6 +494,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -535,6 +573,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -548,6 +589,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. 
+ * * @group agg_funcs * @since 2.0.0 */ @@ -561,6 +605,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -572,6 +619,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -775,6 +825,7 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -994,14 +1045,6 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Computes the absolute value. - * - * @group normal_funcs - * @since 1.3.0 - */ - def abs(e: Column): Column = withExpr { Abs(e.expr) } - /** * Creates a new array column. The input columns must all have the same data type. * @@ -1033,6 +1076,17 @@ object functions { @scala.annotation.varargs def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } + /** + * Creates a new map column. The array in the first column is used for keys. The array in the + * second column is used for values. All elements in the array for key should not be null. + * + * @group normal_funcs + * @since 2.4 + */ + def map_from_arrays(keys: Column, values: Column): Column = withExpr { + MapFromArrays(keys.expr, values.expr) + } + /** * Marks a DataFrame as small enough for use in broadcast joins. * @@ -1172,7 +1226,7 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1183,6 +1237,8 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1192,7 +1248,7 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1203,6 +1259,8 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1211,7 +1269,7 @@ object functions { /** * Partition ID. * - * @note This is indeterministic because it depends on data partitioning and task scheduling. 
+ * @note This is non-deterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs * @since 1.6.0 @@ -1284,7 +1342,7 @@ object functions { } /** - * Computes bitwise NOT. + * Computes bitwise NOT (~) of a number. * * @group normal_funcs * @since 1.4.0 @@ -1312,6 +1370,14 @@ object functions { // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value of a numeric value. + * + * @group math_funcs + * @since 1.3.0 + */ + def abs(e: Column): Column = withExpr { Abs(e.expr) } + /** * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * @@ -1594,7 +1660,7 @@ object functions { def expm1(e: Column): Column = withExpr { Expm1(e.expr) } /** - * Computes the exponential of the given column. + * Computes the exponential of the given column minus one. * * @group math_funcs * @since 1.4.0 @@ -2560,8 +2626,12 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns the date that is numMonths after startDate. + * Returns the date that is `numMonths` after `startDate`. * + * @param startDate A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param numMonths The number of months to add to `startDate`, can be negative to subtract months + * @return A date, or null if `startDate` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2589,12 +2659,15 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * A pattern `dd.MM.yyyy` would return a string like `18.03.1993`. - * All pattern letters of `java.text.SimpleDateFormat` can be used. + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns * + * @param dateExpr A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param format A pattern `dd.MM.yyyy` would return a string like `18.03.1993` + * @return A string, or null if `dateExpr` was a string that could not be cast to a timestamp * @note Use specialized functions like [[year]] whenever possible as they benefit from a * specialized implementation. - * + * @throws IllegalArgumentException if the `format` pattern is invalid * @group datetime_funcs * @since 1.5.0 */ @@ -2604,6 +2677,11 @@ object functions { /** * Returns the date that is `days` days after `start` + * + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days The number of days to add to `start`, can be negative to subtract days + * @return A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2611,6 +2689,11 @@ object functions { /** * Returns the date that is `days` days before `start` + * + * @param start A date, timestamp or string. 
If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days The number of days to subtract from `start`, can be negative to add days + * @return A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2618,6 +2701,19 @@ object functions { /** * Returns the number of days from `start` to `end`. + * + * Only considers the date part of the input. For example: + * {{{ + * dateddiff("2018-01-10 00:00:00", "2018-01-09 23:59:59") + * // returns 1 + * }}} + * + * @param end A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return An integer, or null if either `end` or `start` were strings that could not be cast to + * a date. Negative if `end` is before `start` * @group datetime_funcs * @since 1.5.0 */ @@ -2625,6 +2721,7 @@ object functions { /** * Extracts the year as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2632,6 +2729,7 @@ object functions { /** * Extracts the quarter as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2639,6 +2737,7 @@ object functions { /** * Extracts the month as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2646,6 +2745,8 @@ object functions { /** * Extracts the day of the week as an integer from a given date/timestamp/string. + * Ranges from 1 for a Sunday through to 7 for a Saturday + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 2.3.0 */ @@ -2653,6 +2754,7 @@ object functions { /** * Extracts the day of the month as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2660,6 +2762,7 @@ object functions { /** * Extracts the day of the year as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2667,16 +2770,20 @@ object functions { /** * Extracts the hours as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ def hour(e: Column): Column = withExpr { Hour(e.expr) } /** - * Given a date column, returns the last day of the month which the given date belongs to. + * Returns the last day of the month which the given date belongs to. * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the * month in July 2015. * + * @param e A date, timestamp or string. 
If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A date, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2684,30 +2791,60 @@ object functions { /** * Extracts the minutes as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ def minute(e: Column): Column = withExpr { Minute(e.expr) } /** - * Returns number of months between dates `date1` and `date2`. + * Returns number of months between dates `start` and `end`. + * + * A whole number is returned if both inputs have the same day of month or both are the last day + * of their respective months. Otherwise, the difference is calculated assuming 31 days per month. + * + * For example: + * {{{ + * months_between("2017-11-14", "2017-07-14") // returns 4.0 + * months_between("2017-01-01", "2017-01-10") // returns 0.29032258 + * months_between("2017-06-01", "2017-06-16 12:00:00") // returns -0.5 + * }}} + * + * @param end A date, timestamp or string. If a string, the data must be in a format that can + * be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start A date, timestamp or string. If a string, the data must be in a format that can + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A double, or null if either `end` or `start` were strings that could not be cast to a + * timestamp. Negative if `end` is before `start` * @group datetime_funcs * @since 1.5.0 */ - def months_between(date1: Column, date2: Column): Column = withExpr { - MonthsBetween(date1.expr, date2.expr) + def months_between(end: Column, start: Column): Column = withExpr { + new MonthsBetween(end.expr, start.expr) + } + + /** + * Returns number of months between dates `end` and `start`. If `roundOff` is set to true, the + * result is rounded off to 8 digits; it is not rounded otherwise. + * @group datetime_funcs + * @since 2.4.0 + */ + def months_between(end: Column, start: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(end.expr, start.expr, lit(roundOff).expr) } /** - * Given a date column, returns the first date which is later than the value of the date column - * that is on the specified day of the week. + * Returns the first date which is later than the value of the `date` column that is on the + * specified day of the week. * * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first * Sunday after 2015-07-27. * - * Day of the week parameter is case insensitive, and accepts: - * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". - * + * @param date A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param dayOfWeek Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun" + * @return A date, or null if `date` was a string that could not be cast to a date or if + * `dayOfWeek` was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2717,6 +2854,7 @@ object functions { /** * Extracts the seconds as an integer from a given date/timestamp/string. 
+ * @return An integer, or null if the input was a string that could not be cast to a timestamp * @group datetime_funcs * @since 1.5.0 */ @@ -2724,6 +2862,11 @@ object functions { /** * Extracts the week number as an integer from a given date/timestamp/string. + * + * A week is considered to start on a Monday and week 1 is the first week with more than 3 days, + * as defined by ISO 8601 + * + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2731,8 +2874,12 @@ object functions { /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string - * representing the timestamp of that moment in the current system time zone in the given - * format. + * representing the timestamp of that moment in the current system time zone in the + * yyyy-MM-dd HH:mm:ss format. + * + * @param ut A number of a type that is castable to a long, such as string or integer. Can be + * negative for timestamps before the unix epoch + * @return A string, or null if the input was a string that could not be cast to a long * @group datetime_funcs * @since 1.5.0 */ @@ -2744,6 +2891,14 @@ object functions { * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string * representing the timestamp of that moment in the current system time zone in the given * format. + * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param ut A number of a type that is castable to a long, such as string or integer. Can be + * negative for timestamps before the unix epoch + * @param f A date time pattern that the input will be formatted to + * @return A string, or null if `ut` was a string that could not be cast to a long or `f` was + * an invalid date time pattern * @group datetime_funcs * @since 1.5.0 */ @@ -2752,7 +2907,7 @@ object functions { } /** - * Returns the current Unix timestamp (in seconds). + * Returns the current Unix timestamp (in seconds) as a long. * * @note All calls of `unix_timestamp` within the same query return the same value * (i.e. the current timestamp is calculated at the start of query evaluation). @@ -2767,8 +2922,10 @@ object functions { /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), * using the default timezone and the default locale. - * Returns `null` if fails. * + * @param s A date, timestamp or string. If a string, the data must be in the + * `yyyy-MM-dd HH:mm:ss` format + * @return A long, or null if the input was a string not of the correct format * @group datetime_funcs * @since 1.5.0 */ @@ -2778,17 +2935,25 @@ object functions { /** * Converts time string with given pattern to Unix timestamp (in seconds). - * Returns `null` if fails. * - * @see - * Customizing Formats + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param p A date time pattern detailing the format of `s` when `s` is a string + * @return A long, or null if `s` was a string that could not be cast to a date or `p` was + * an invalid format * @group datetime_funcs * @since 1.5.0 */ def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } /** - * Convert time string to a Unix timestamp (in seconds) by casting rules to `TimestampType`. 
+ * Converts to a timestamp by casting rules to `TimestampType`. + * + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A timestamp, or null if the input was a string that could not be cast to a timestamp * @group datetime_funcs * @since 2.2.0 */ @@ -2797,9 +2962,15 @@ object functions { } /** - * Convert time string to a Unix timestamp (in seconds) with a specified format - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix timestamp (in seconds), return null if fail. + * Converts time string with the given pattern to timestamp. + * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param fmt A date time pattern detailing the format of `s` when `s` is a string + * @return A timestamp, or null if `s` was a string that could not be cast to a timestamp or + * `fmt` was an invalid format * @group datetime_funcs * @since 2.2.0 */ @@ -2817,9 +2988,14 @@ object functions { /** * Converts the column into a `DateType` with a specified format - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * return null if fail. * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param e A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param fmt A date time pattern detailing the format of `e` when `e`is a string + * @return A date, or null if `e` was a string that could not be cast to a date or `fmt` was an + * invalid format * @group datetime_funcs * @since 2.2.0 */ @@ -2830,9 +3006,15 @@ object functions { /** * Returns date truncated to the unit specified by the format. * + * For example, `trunc("2018-11-19 12:01:19", "year")` returns 2018-01-01 + * + * @param date A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` * @param format: 'year', 'yyyy', 'yy' for truncate by year, * or 'month', 'mon', 'mm' for truncate by month * + * @return A date, or null if `date` was a string that could not be cast to a date or `format` + * was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2843,11 +3025,16 @@ object functions { /** * Returns timestamp truncated to the unit specified by the format. * + * For example, `date_tunc("2018-11-19 12:01:19", "year")` returns 2018-01-01 00:00:00 + * * @param format: 'year', 'yyyy', 'yy' for truncate by year, * 'month', 'mon', 'mm' for truncate by month, * 'day', 'dd' for truncate by day, * Other options are: 'second', 'minute', 'hour', 'week', 'month', 'quarter' - * + * @param timestamp A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A timestamp, or null if `timestamp` was a string that could not be cast to a timestamp + * or `format` was an invalid value * @group datetime_funcs * @since 2.3.0 */ @@ -2859,6 +3046,13 @@ object functions { * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. 
For example, 'GMT+1' would yield * '2017-07-14 03:40:00.0'. + * + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone that the input should be adjusted to, such as + * `Europe/London`, `PST` or `GMT+5` + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2866,10 +3060,28 @@ object functions { FromUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def from_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + FromUTCTimestamp(ts.expr, tz.expr) + } + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield * '2017-07-14 01:40:00.0'. + * + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone that the input belongs to, such as `Europe/London`, + * `PST` or `GMT+5` + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2877,6 +3089,17 @@ object functions { ToUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def to_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + ToUTCTimestamp(ts.expr, tz.expr) + } + /** * Bucketize rows into one or more time windows given a timestamp specifying column. Window * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window @@ -3025,7 +3248,47 @@ object functions { * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, Literal(value)) + ArrayContains(column.expr, lit(value).expr) + } + + /** + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and both + * the arrays are non-empty and any of them contains a `null`, it returns `null`. It returns + * `false` otherwise. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_overlap(a1: Column, a2: Column): Column = withExpr { + ArraysOverlap(a1.expr, a2.expr) + } + + /** + * Returns an array containing all the elements in `x` from index `start` (or starting from the + * end if `start` is negative) with the specified `length`. + * @group collection_funcs + * @since 2.4.0 + */ + def slice(x: Column, start: Int, length: Int): Column = withExpr { + Slice(x.expr, Literal(start), Literal(length)) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + * `nullReplacement`. 
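A short sketch of the 2.4 additions documented above (the two-`Column` `from_utc_timestamp`, `arrays_overlap`, `slice` and `array_join`); `spark` is an assumed `SparkSession` and the sample data is invented:
{{{
import org.apache.spark.sql.functions._

val df = spark.sql(
  "SELECT CAST('2017-07-14 02:40:00' AS TIMESTAMP) AS ts, 'Europe/London' AS tz, " +
  "array('a', 'b', 'c', 'd') AS xs")
df.select(
  from_utc_timestamp(col("ts"), col("tz")),             // per-row time zone (2.4 overload)
  arrays_overlap(col("xs"), array(lit("c"), lit("z"))), // true: 'c' is shared
  slice(col("xs"), 2, 2),                               // ["b", "c"]
  array_join(col("xs"), ", ")                           // "a, b, c, d"
).show(false)
}}}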
+ * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement))) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), None) } /** @@ -3049,7 +3312,7 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, Literal(value)) + ArrayPosition(column.expr, lit(value).expr) } /** @@ -3060,7 +3323,65 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, Literal(value)) + ElementAt(column.expr, lit(value).expr) + } + + /** + * Sorts the input array in ascending order. The elements of the input array must be orderable. + * Null elements will be placed at the end of the returned array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + + /** + * Remove all elements that equal to element from the given array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_remove(column: Column, element: Any): Column = withExpr { + ArrayRemove(column.expr, lit(element).expr) + } + + /** + * Removes duplicate values from the array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + + /** + * Returns an array of the elements in the intersection of the given two arrays, + * without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_intersect(col1: Column, col2: Column): Column = withExpr { + ArrayIntersect(col1.expr, col2.expr) + } + + /** + * Returns an array of the elements in the union of the given two arrays, without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_union(col1: Column, col2: Column): Column = withExpr { + ArrayUnion(col1.expr, col2.expr) + } + + /** + * Returns an array of the elements in the first array but not in the second array, + * without duplicates. The order of elements in the result is not determined + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_except(col1: Column, col2: Column): Column = withExpr { + ArrayExcept(col1.expr, col2.expr) } /** @@ -3136,9 +3457,9 @@ object functions { from_json(e, schema.asInstanceOf[DataType], options) /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3168,9 +3489,9 @@ object functions { from_json(e, schema, options.asScala.toMap) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. 
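The array set operations added in 2.4 compose naturally; a minimal sketch assuming a `SparkSession` named `spark`:
{{{
import org.apache.spark.sql.functions._

val df = spark.sql("SELECT array(1, 2, 2, 3) AS a, array(3, 4) AS b")
df.select(
  array_distinct(col("a")),            // [1, 2, 3]
  array_remove(col("a"), 2),           // [1, 3]
  array_sort(col("a")),                // [1, 2, 2, 3]; nulls would be placed last
  array_intersect(col("a"), col("b")), // [3]
  array_union(col("a"), col("b")),     // [1, 2, 3, 4]
  array_except(col("a"), col("b"))     // [1, 2]
).show(false)
}}}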
+ * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3197,8 +3518,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s - * with the specified schema. Returns `null`, in the case of an unparseable string. + * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, + * `StructType` or `ArrayType` with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3210,9 +3532,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, @@ -3227,9 +3549,9 @@ object functions { } /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string, it could be a @@ -3242,11 +3564,53 @@ object functions { val dataType = try { DataType.fromJson(schema) } catch { - case NonFatal(_) => StructType.fromDDL(schema) + case NonFatal(_) => DataType.fromDDL(schema) } from_json(e, dataType, options) } + /** + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column): Column = { + from_json(e, schema, Map.empty[String, String].asJava) + } + + /** + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. 
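With the switch to `DataType.fromDDL` above, the DDL-string overloads of `from_json` can describe a map as well as a struct. A minimal sketch, assuming a `SparkSession` named `spark` and invented sample JSON:
{{{
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq("""{"a": 1, "b": 2}""").toDF("json")
// The schema string is DDL; it no longer has to be a struct definition.
df.select(from_json(col("json"), "MAP<STRING, INT>", Map.empty[String, String]))
  .show(false) // a map column, roughly [a -> 1, b -> 2]
}}}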
+ * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { + withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap)) + } + + /** + * Parses a column containing a JSON string and infers its schema. + * + * @param e a string column containing JSON data. + * + * @group collection_funcs + * @since 2.4.0 + */ + def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr)) + /** * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. @@ -3302,6 +3666,7 @@ object functions { /** * Sorts the input array for the given column in ascending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array. * * @group collection_funcs * @since 1.5.0 @@ -3311,6 +3676,8 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array in ascending order or + * at the end of the returned array in descending order. * * @group collection_funcs * @since 1.5.0 @@ -3333,6 +3700,16 @@ object functions { */ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** + * Returns a random permutation of the given array. + * + * @note The function is non-deterministic. + * + * @group collection_funcs + * @since 2.4.0 + */ + def shuffle(e: Column): Column = withExpr { Shuffle(e.expr) } + /** * Returns a reversed string or an array with reverse order of elements. * @group collection_funcs @@ -3340,6 +3717,55 @@ object functions { */ def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than + * two levels, only one level of nesting is removed. + * @group collection_funcs + * @since 2.4.0 + */ + def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + + /** + * Generate a sequence of integers from start to stop, incrementing by step. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column, step: Column): Column = withExpr { + new Sequence(start.expr, stop.expr, step.expr) + } + + /** + * Generate a sequence of integers from start to stop, + * incrementing by 1 if start is less than or equal to stop, otherwise -1. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column): Column = withExpr { + new Sequence(start.expr, stop.expr) + } + + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(left: Column, right: Column): Column = withExpr { + ArrayRepeat(left.expr, right.expr) + } + + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count)) + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs @@ -3354,6 +3780,37 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + /** + * Returns an unordered array of all entries in the given map. 
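The new `schema_of_json` and the `Column`-schema `from_json` overload shown earlier can be used together with the other 2.4 array builders; a sketch assuming a `SparkSession` named `spark` (the sample JSON is invented):
{{{
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq("""{"id": 7, "tags": ["a", "b"]}""").toDF("json")
// Infer a schema string from a sample document...
df.select(schema_of_json(lit("""{"id": 0, "tags": ["x"]}"""))).show(false)
// ...or hand a DDL schema to the new Column-based from_json overload.
df.select(from_json(col("json"), lit("id BIGINT, tags ARRAY<STRING>"))).show(false)

// sequence and array_repeat, also new in 2.4:
spark.range(1).select(
  sequence(lit(1), lit(5)),   // [1, 2, 3, 4, 5]
  array_repeat(lit("x"), 3)   // [x, x, x]
).show(false)
}}}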
+ * @group collection_funcs + * @since 2.4.0 + */ + def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + + /** + * Returns a map created from the given array of entries. + * @group collection_funcs + * @since 2.4.0 + */ + def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } + + /** + * Returns a merged array of structs in which the N-th struct contains all N-th values of input + * arrays. + * @group collection_funcs + * @since 2.4.0 + */ + @scala.annotation.varargs + def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + + /** + * Returns the union of all the given maps. + * @group collection_funcs + * @since 2.4.0 + */ + @scala.annotation.varargs + def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number @@ -3362,7 +3819,7 @@ object functions { (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]) :: $s"}) println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF). @@ -3436,7 +3893,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3452,7 +3909,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3468,7 +3925,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3484,7 +3941,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: 
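A small sketch of the map and zip helpers introduced above, assuming a `SparkSession` named `spark`:
{{{
import org.apache.spark.sql.functions._

val df = spark.sql(
  "SELECT array(1, 2) AS a, array('x', 'y') AS b, map('k1', 1) AS m1, map('k2', 2) AS m2")
df.select(
  arrays_zip(col("a"), col("b")),           // [[1, x], [2, y]] as array<struct>
  map_concat(col("m1"), col("m2")),         // [k1 -> 1, k2 -> 2]
  map_entries(col("m1")),                   // [[k1, 1]] as array<struct<key, value>>
  map_from_entries(map_entries(col("m1")))  // round-trips back to a map
).show(false)
}}}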
ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3500,7 +3957,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3516,7 +3973,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3532,7 +3989,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3548,7 +4005,7 @@ object functions { */ def udf[RT: TypeTag, A1: 
TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3564,7 +4021,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3580,7 +4037,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: 
ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: ScalaReflection.schemaFor(typeTag[A10]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 6ae307bce10c8..4698e8ab13ce3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + val cascade = !sessionCatalog.isTemporaryTable(tableIdent) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName), cascade) } /** @@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // cached version and make the new version cached lazily. if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, cascade = true, blocking = true) // Cache it again. sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index baea4ceebf8e3..5b6160e2b408f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -99,7 +99,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. 
*/ - lazy val externalCatalog: ExternalCatalog = { + lazy val externalCatalog: ExternalCatalogWithListener = { val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, @@ -117,14 +117,17 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } + // Wrap to provide catalog events + val wrapped = new ExternalCatalogWithListener(externalCatalog) + // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { + wrapped.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { sparkContext.listenerBus.post(event) } }) - externalCatalog + wrapped } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 8b92c8b4f56b5..3a3246a1b1d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -64,7 +64,16 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } } - override def getTruncateQuery(table: String): String = { - dialects.head.getTruncateQuery(table) + /** + * The SQL query used to truncate a table. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + dialects.head.getTruncateQuery(table, cascade) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 84f68e779c38c..d13c29ed46bd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -41,4 +41,6 @@ private object DerbyDialect extends JdbcDialect { Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) case _ => None } + + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 83d87a11810c1..f76c1fae562c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -22,6 +22,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.types._ /** @@ -120,12 +121,27 @@ abstract class JdbcDialect extends Serializable { * The SQL query that should be used to truncate a table. Dialects can override this method to * return a query that is suitable for a particular database. For PostgreSQL, for instance, * a different query is used to prevent "TRUNCATE" affecting other tables. - * @param table The name of the table. 
+ * @param table The table to truncate * @return The SQL query to use for truncating a table */ @Since("2.3.0") def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE $table" + getTruncateQuery(table, isCascadingTruncateTable) + } + + /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation + * @return The SQL query to use for truncating a table + */ + @Since("2.4.0") + def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"TRUNCATE TABLE $table" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 6ef77f24460be..f4a6d0a4d2e44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -95,4 +95,20 @@ private case object OracleDialect extends JdbcDialect { } override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE $table CASCADE" + case _ => s"TRUNCATE TABLE $table" + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 13a2035f4d0c4..f8d2bc8e0f13f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,15 +85,27 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + /** - * The SQL query used to truncate a table. For Postgres, the default behaviour is to - * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, - * the Postgres dialect adds 'ONLY' to truncate only the table in question - * @param table The name of the table. - * @return The SQL query to use for truncating a table - */ - override def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE ONLY $table" + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the value of + * isCascadingTruncateTable(). Cascading a truncation will truncate tables + * with a foreign key relationship to the target table. However, it will not + * truncate tables with an inheritance relationship to the target table, as + * the truncate query always includes "ONLY" to prevent this behaviour. 
+ * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE" + case _ => s"TRUNCATE TABLE ONLY $table" + } } override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { @@ -110,5 +122,4 @@ private object PostgresDialect extends JdbcDialect { } } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 5749b791fca25..6c17bd7ed9ec4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -31,4 +31,22 @@ private case object TeradataDialect extends JdbcDialect { case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) case _ => None } + + // Teradata does not support cascading a truncation + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. Teradata does not support the 'TRUNCATE' syntax that + * other dialects use. Instead, we need to use a 'DELETE FROM' statement. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable(). Teradata does not support cascading a + * 'DELETE FROM' statement (and as mentioned, does not support 'TRUNCATE' syntax) + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"DELETE FROM $table ALL" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 2499e9b604f3e..bdd8c4da6bd30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -199,7 +199,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to - * a string that starts with `value`. + * a string that ends with `value`. 
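Third-party dialects can override the same two-argument `getTruncateQuery` and register themselves; a hypothetical sketch (the "acme" dialect and URL prefix are invented, only `JdbcDialect`/`JdbcDialects` come from Spark):
{{{
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

object AcmeDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:acme")

  // Truncation should not spill over into related tables unless asked to.
  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)

  override def getTruncateQuery(
      table: String,
      cascade: Option[Boolean] = isCascadingTruncateTable): String = cascade match {
    case Some(true) => s"TRUNCATE TABLE $table CASCADE"
    case _ => s"TRUNCATE TABLE $table"
  }
}

JdbcDialects.registerDialect(AcmeDialect)
}}}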
* * @since 1.3.1 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ae93965bc50ed..39e9e1ad426be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.{Locale, Optional} +import java.util.Locale import scala.collection.JavaConverters._ @@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,19 +172,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupport => - var tempReader: MicroBatchReader = null + case s: MicroBatchReadSupportProvider => + var tempReadSupport: MicroBatchReadSupport = null val schema = try { - tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) - tempReader.readSchema() + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createMicroBatchReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() } finally { // Stop tempReader to avoid side-effect thing - if (tempReader != null) { - tempReader.stop() - tempReader = null + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null } } Dataset.ofRows( @@ -192,16 +194,28 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo StreamingRelationV2( s, source, extraOptions.toMap, schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupport => - val tempReader = s.createContinuousReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + case s: ContinuousReadSupportProvider => + var tempReadSupport: ContinuousReadSupport = null + val schema = try { + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createContinuousReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - 
tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) @@ -270,6 +284,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * per file *
 *   <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
 * that should be used for parsing.</li>
+ *   <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
+ * empty array/struct during schema inference.</li>
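One possible way to exercise the new option on a streaming JSON source — a sketch, not taken from the patch; the input path is invented, and schema inference on a stream additionally needs `spark.sql.streaming.schemaInference`:
{{{
spark.conf.set("spark.sql.streaming.schemaInference", "true")

val jsonStream = spark.readStream
  .format("json")
  .option("dropFieldIfAllNull", "true") // drop all-null / empty-only columns while inferring
  .load("/data/incoming/json")          // hypothetical directory
}}}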
    • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e12..7866e4f70f14b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -22,14 +22,15 @@ import java.util.Locale import scala.collection.JavaConverters._ import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.api.java.function.VoidFunction2 +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -269,7 +270,22 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) + val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) + df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + extraOptions.toMap, + sink, + outputMode, + useTempCheckpointLocation = true, + trigger = trigger) + } else if (source == "foreachBatch") { + assertNotPartitioned("foreachBatch") + if (trigger.isInstanceOf[ContinuousTrigger]) { + throw new AnalysisException("'foreachBatch' is not supported with continuous trigger") + } + val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -283,7 +299,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") val sink = ds.newInstance() match { - case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w + case w: StreamingWriteSupportProvider + if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, @@ -307,49 +324,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data - * generated by the `DataFrame`/`Dataset` to an external system. 
- * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }).start() - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }).start(); - * }}} - * + * Sets the output of the streaming query to be processed using the provided writer object. + * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. * @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { @@ -362,6 +339,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * :: Experimental :: + * + * (Scala-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + this.source = "foreachBatch" + if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") + this.foreachBatchWriter = function + this + } + + /** + * :: Experimental :: + * + * (Java-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). 
+ * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = { + foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) + } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => cols.map(normalize(_, "Partition")) } @@ -398,5 +414,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var foreachWriter: ForeachWriter[T] = null + private var foreachBatchWriter: (Dataset[T], Long) => Unit = null + private var partitioningColumns: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7cefd03e43bc3..cd52d991d55c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} @@ -32,7 +33,8 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -55,6 +57,19 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo @GuardedBy("awaitTerminationLock") private var lastTerminatedQuery: StreamingQuery = null + try { + sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[StreamingQueryListener], classNames, + sparkSession.sparkContext.conf).foreach(listener => { + addListener(listener) + logInfo(s"Registered listener ${listener.getClass.getName}") + }) + } + } catch { + case e: Exception => + throw new SparkException("Exception when registering StreamingQueryListener", e) + } + /** * Returns a list of active queries associated with this SQLContext * @@ -241,8 +256,10 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + } new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 0dcb666e2c3e4..cf9375d39b39d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -38,7 +38,8 @@ import 
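A minimal Scala sketch of `foreachBatch`, assuming a `SparkSession` named `spark`; the rate source and output paths are only for illustration:
{{{
import org.apache.spark.sql.DataFrame

val events: DataFrame = spark.readStream.format("rate").load()

// Declaring the function with an explicit type keeps the Scala overload unambiguous.
val writeBatch: (DataFrame, Long) => Unit = (batchDF, batchId) => {
  // Reuse any batch sink; keying the write on batchId helps make retries idempotent.
  batchDF.write.mode("append").parquet(s"/tmp/rate-output/batch=$batchId")
}

val query = events.writeStream
  .foreachBatch(writeBatch)
  .option("checkpointLocation", "/tmp/rate-checkpoint") // hypothetical path
  .start()
}}}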
org.apache.spark.annotation.InterfaceStability class StateOperatorProgress private[sql]( val numRowsTotal: Long, val numRowsUpdated: Long, - val memoryUsedBytes: Long + val memoryUsedBytes: Long, + val customMetrics: ju.Map[String, JLong] = new ju.HashMap() ) extends Serializable { /** The compact JSON representation of this progress. */ @@ -48,12 +49,20 @@ class StateOperatorProgress private[sql]( def prettyJson: String = pretty(render(jsonValue)) private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress = - new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes) + new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, customMetrics) private[sql] def jsonValue: JValue = { ("numRowsTotal" -> JInt(numRowsTotal)) ~ ("numRowsUpdated" -> JInt(numRowsUpdated)) ~ - ("memoryUsedBytes" -> JInt(memoryUsedBytes)) + ("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~ + ("customMetrics" -> { + if (!customMetrics.isEmpty) { + val keys = customMetrics.keySet.asScala.toSeq.sorted + keys.map { k => k -> JInt(customMetrics.get(k).toLong) : JObject }.reduce(_ ~ _) + } else { + JNothing + } + }) } override def toString: String = prettyJson @@ -163,7 +172,27 @@ class SourceProgress protected[sql]( val endOffset: String, val numInputRows: Long, val inputRowsPerSecond: Double, - val processedRowsPerSecond: Double) extends Serializable { + val processedRowsPerSecond: Double, + val customMetrics: String) extends Serializable { + + /** SourceProgress without custom metrics. */ + protected[sql] def this( + description: String, + startOffset: String, + endOffset: String, + numInputRows: Long, + inputRowsPerSecond: Double, + processedRowsPerSecond: Double) { + + this( + description, + startOffset, + endOffset, + numInputRows, + inputRowsPerSecond, + processedRowsPerSecond, + null) + } /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -178,12 +207,18 @@ class SourceProgress protected[sql]( if (value.isNaN || value.isInfinity) JNothing else JDouble(value) } - ("description" -> JString(description)) ~ + val jsonVal = ("description" -> JString(description)) ~ ("startOffset" -> tryParse(startOffset)) ~ ("endOffset" -> tryParse(endOffset)) ~ ("numInputRows" -> JInt(numInputRows)) ~ ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~ ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) + + if (customMetrics != null) { + jsonVal ~ ("customMetrics" -> parse(customMetrics)) + } else { + jsonVal + } } private def tryParse(json: String) = try { @@ -202,7 +237,13 @@ class SourceProgress protected[sql]( */ @InterfaceStability.Evolving class SinkProgress protected[sql]( - val description: String) extends Serializable { + val description: String, + val customMetrics: String) extends Serializable { + + /** SinkProgress without custom metrics. */ + protected[sql] def this(description: String) { + this(description, null) + } /** The compact JSON representation of this progress. 
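Once a query is running, the new `customMetrics` surface through the existing progress API; a sketch assuming a started `StreamingQuery` named `query`:
{{{
// `query` is a running org.apache.spark.sql.streaming.StreamingQuery.
val progress = query.lastProgress
if (progress != null) {
  progress.stateOperators.foreach(op => println(op.customMetrics)) // empty map when unused
  println(progress.json) // customMetrics fields appear in the JSON only when present
}
}}}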
*/ def json: String = compact(render(jsonValue)) @@ -213,6 +254,12 @@ class SinkProgress protected[sql]( override def toString: String = prettyJson private[sql] def jsonValue: JValue = { - ("description" -> JString(description)) + val jsonVal = ("description" -> JString(description)) + + if (customMetrics != null) { + jsonVal ~ ("customMetrics" -> parse(customMetrics)) + } else { + jsonVal + } } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index c132cab1b38cf..2c695fc58fd8c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -34,6 +34,7 @@ import org.junit.*; import org.junit.rules.ExpectedException; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; @@ -336,6 +337,23 @@ public void testTupleEncoder() { Assert.assertEquals(data5, ds5.collectAsList()); } + @Test + public void testTupleEncoderSchema() { + Encoder>> encoder = + Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())); + List>> data = Arrays.asList(tuple2("1", tuple2("a", "b")), + tuple2("2", tuple2("c", "d"))); + Dataset ds1 = spark.createDataset(data, encoder).toDF("value1", "value2"); + + JavaPairRDD> pairRDD = jsc.parallelizePairs(data); + Dataset ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder) + .toDF("value1", "value2"); + + Assert.assertEquals(ds1.schema(), ds2.schema()); + Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(), + ds2.select(expr("value2._1")).collectAsList()); + } + @Test public void testNestedTupleEncoder() { // test ((int, string), string) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java new file mode 100644 index 0000000000000..97f3dc588ecc5 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.execution.sort; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.RecordBinaryComparator; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.collection.unsafe.sort.*; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form. + */ +public class RecordBinaryComparatorSuite { + + private final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + + private MemoryBlock dataPage; + private long pageCursor; + + private LongArray array; + private int pos; + + @Before + public void beforeEach() { + // Only compare between two input rows. + array = consumer.allocateArray(2); + pos = 0; + + dataPage = memoryManager.allocatePage(4096, consumer); + pageCursor = dataPage.getBaseOffset(); + } + + @After + public void afterEach() { + consumer.freePage(dataPage); + dataPage = null; + pageCursor = 0; + + consumer.freeArray(array); + array = null; + pos = 0; + } + + private void insertRow(UnsafeRow row) { + Object recordBase = row.getBaseObject(); + long recordOffset = row.getBaseOffset(); + int recordLength = row.getSizeInBytes(); + + Object baseObject = dataPage.getBaseObject(); + assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size()); + long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor); + UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength); + pageCursor += recordLength; + + assert(pos < 2); + array.set(pos, recordAddress); + pos++; + } + + private int compare(int index1, int index2) { + Object baseObject = dataPage.getBaseObject(); + + long recordAddress1 = array.get(index1); + long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize; + int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize); + + long recordAddress2 = array.get(index2); + long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize; + int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize); + + return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject, + baseOffset2, recordLength2); + } + + private final RecordComparator binaryComparator = new RecordBinaryComparator(); + + // Compute the most compact size for UnsafeRow's backing data. + private int computeSizeInBytes(int originalSize) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + (originalSize + 7) / 8 * 8; + } + + // Compute the relative offset of variable-length values. 
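+  // The rows built in this suite follow the UnsafeRow layout: an 8-byte null bit set (enough for
+  // fewer than 64 columns) followed by one 8-byte slot per field, with variable-length data
+  // appended after the fixed-width region, which is why the offset below is 8 + numFields * 8.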
+ private long relativeOffset(int numFields) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + numFields * 8L; + } + + @Test + public void testBinaryComparatorForSingleColumnRow() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 42); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForMultipleColumnRow() throws Exception { + int numFields = 5; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setDouble(i, i * 3.14); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row2.setDouble(i, 198.7 / (i + 1)); + } + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForArrayColumn() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1}); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes())); + row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes()); + Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22}); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes())); + row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes()); + Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForMixedColumns() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UTF8String str1 = UTF8String.fromString("Milk tea"); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes())); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes()); + Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UTF8String str2 = UTF8String.fromString("Java"); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes())); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes()); + 
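+      // relativeOffset(4) evaluates to 8 + 4 * 8 = 40, so this long packs the offset 40 into the
+      // upper 32 bits and the 4-byte UTF-8 length of "Java" into the lower 32 bits; the string
+      // bytes are then copied to that offset by the copyMemory call below.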
Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForNullColumns() throws Exception { + int numFields = 3; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setNullAt(i); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields - 1; i++) { + row2.setNullAt(i); + } + row2.setDouble(numFields - 1, 3.14); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 11L + Integer.MAX_VALUE); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, Long.MIN_VALUE); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, 0); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 172e5d5eebcbe..5602310219a74 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -20,33 +20,75 @@ import java.io.IOException; import java.util.*; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; 
import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { +public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, - SupportsPushDownFilters { + public class ReadSupport extends JavaSimpleReadSupport { + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new AdvancedScanConfigBuilder(); + } + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; + List res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaRangeInputPartition(0, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 4) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 9) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + } + + return res.stream().toArray(InputPartition[]::new); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; + return new AdvancedReaderFactory(requiredSchema); + } + } + + public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, + SupportsPushDownFilters, SupportsPushDownRequiredColumns { // Exposed for testing. 
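+    // AdvancedScanConfigBuilder also serves as the ScanConfig (build() returns this), so the
+    // pushed filters and pruned schema recorded in these fields are read back by
+    // planInputPartitions and createReaderFactory above.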
public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); public Filter[] filters = new Filter[0]; @Override - public StructType readSchema() { - return requiredSchema; + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; } @Override - public void pruneColumns(StructType requiredSchema) { - this.requiredSchema = requiredSchema; + public StructType readSchema() { + return requiredSchema; } @Override @@ -79,78 +121,54 @@ public Filter[] pushedFilters() { } @Override - public List> createDataReaderFactories() { - List> res = new ArrayList<>(); - - Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { - lowerBound = (Integer) f.value(); - break; - } - } - } - - if (lowerBound == null) { - res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); - } else if (lowerBound < 4) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); - } else if (lowerBound < 9) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); - } - - return res; + public ScanConfig build() { + return this; } } - static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { - private int start; - private int end; - private StructType requiredSchema; + static class AdvancedReaderFactory implements PartitionReaderFactory { + StructType requiredSchema; - JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { - this.start = start; - this.end = end; + AdvancedReaderFactory(StructType requiredSchema) { this.requiredSchema = requiredSchema; } @Override - public DataReader createDataReader() { - return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } - @Override - public Row get() { - Object[] values = new Object[requiredSchema.size()]; - for (int i = 0; i < values.length; i++) { - if ("i".equals(requiredSchema.apply(i).name())) { - values[i] = start; - } else if ("j".equals(requiredSchema.apply(i).name())) { - values[i] = -start; + @Override + public InternalRow get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = current; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -current; + } + } + return new GenericInternalRow(values); } - } - return new GenericRow(values); - } - @Override - public void close() throws IOException { + @Override + public void close() throws IOException { + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java deleted file mode 100644 index c55093768105b..0000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; - - -public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader, SupportsScanColumnarBatch { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List> createBatchDataReaderFactories() { - return java.util.Arrays.asList( - new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); - } - } - - static class JavaBatchDataReaderFactory - implements DataReaderFactory, DataReader { - private int start; - private int end; - - private static final int BATCH_SIZE = 20; - - private OnHeapColumnVector i; - private OnHeapColumnVector j; - private ColumnarBatch batch; - - JavaBatchDataReaderFactory(int start, int end) { - this.start = start; - this.end = end; - } - - @Override - public DataReader createDataReader() { - this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - ColumnVector[] vectors = new ColumnVector[2]; - vectors[0] = i; - vectors[1] = j; - this.batch = new ColumnarBatch(vectors); - return this; - } - - @Override - public boolean next() { - i.reset(); - j.reset(); - int count = 0; - while (start < end && count < BATCH_SIZE) { - i.putInt(count, start); - j.putInt(count, -start); - start += 1; - count += 1; - } - - if (count == 0) { - return false; - } else { - batch.setNumRows(count); - return true; - } - } - - @Override - public ColumnarBatch get() { - return batch; - } - - @Override - public void close() throws IOException { - batch.close(); - } - } - - - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); - } -} diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java new file mode 100644 index 0000000000000..28a9330398310 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { + + class ReadSupport extends JavaSimpleReadSupport { + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 50); + partitions[1] = new JavaRangeInputPartition(50, 90); + return partitions; + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new ColumnarReaderFactory(); + } + } + + static class ColumnarReaderFactory implements PartitionReaderFactory { + private static final int BATCH_SIZE = 20; + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + + @Override + public PartitionReader createReader(InputPartition partition) { + throw new UnsupportedOperationException(""); + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + ColumnarBatch batch = new ColumnarBatch(vectors); + + return new PartitionReader() { + private int current = p.start; + + @Override + public boolean next() throws IOException { + i.reset(); + j.reset(); + int count = 0; + while (current < p.end && count < BATCH_SIZE) { + i.putInt(count, current); + j.putInt(count, -current); + current += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + 
public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + }; + } + } + + @Override + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 32fad59b97ff6..18a11dde82198 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -19,38 +19,34 @@ import java.io.IOException; import java.util.Arrays; -import java.util.List; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; -import org.apache.spark.sql.types.StructType; -public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { +public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader, SupportsReportPartitioning { - private final StructType schema = new StructType().add("a", "int").add("b", "int"); + class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { @Override - public StructType readSchema() { - return schema; + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); + partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); + return partitions; } @Override - public List> createDataReaderFactories() { - return java.util.Arrays.asList( - new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new SpecificReaderFactory(); } @Override - public Partitioning outputPartitioning() { + public Partitioning outputPartitioning(ScanConfig config) { return new MyPartitioning(); } } @@ -66,48 +62,53 @@ public int numPartitions() { public boolean satisfy(Distribution distribution) { if (distribution instanceof ClusteredDistribution) { String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; - return Arrays.asList(clusteredCols).contains("a"); + return Arrays.asList(clusteredCols).contains("i"); } return false; } } - static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { - private int[] i; - private int[] j; - private int current = -1; + static class SpecificInputPartition implements InputPartition { + int[] i; + int[] j; - SpecificDataReaderFactory(int[] i, int[] j) { + SpecificInputPartition(int[] i, int[] j) 
{ assert i.length == j.length; this.i = i; this.j = j; } + } - @Override - public boolean next() throws IOException { - current += 1; - return current < i.length; - } - - @Override - public Row get() { - return new GenericRow(new Object[] {i[current], j[current]}); - } - - @Override - public void close() throws IOException { - - } + static class SpecificReaderFactory implements PartitionReaderFactory { @Override - public DataReader createDataReader() { - return this; + public PartitionReader createReader(InputPartition partition) { + SpecificInputPartition p = (SpecificInputPartition) partition; + return new PartitionReader() { + private int current = -1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.i.length; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {p.i[current], p.j[current]}); + } + + @Override + public void close() throws IOException { + + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 048d078dfaac4..cc9ac04a0dad3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,38 +17,39 @@ package test.org.apache.spark.sql.sources.v2; -import java.util.List; - -import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { +public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader { + class ReadSupport extends JavaSimpleReadSupport { private final StructType schema; - Reader(StructType schema) { + ReadSupport(StructType schema) { this.schema = schema; } @Override - public StructType readSchema() { + public StructType fullSchema() { return schema; } @Override - public List> createDataReaderFactories() { - return java.util.Collections.emptyList(); + public InputPartition[] planInputPartitions(ScanConfig config) { + return new InputPartition[0]; } } @Override - public DataSourceReader createReader(StructType schema, DataSourceOptions options) { - return new Reader(schema); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + throw new IllegalArgumentException("requires a user-supplied schema"); + } + + @Override + public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return new ReadSupport(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 
96f55b8a76811..2cdbba84ec4a4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,70 +17,26 @@ package test.org.apache.spark.sql.sources.v2; -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.types.StructType; - -public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); +import org.apache.spark.sql.sources.v2.reader.*; - @Override - public StructType readSchema() { - return schema; - } +public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - @Override - public List> createDataReaderFactories() { - return java.util.Arrays.asList( - new JavaSimpleDataReaderFactory(0, 5), - new JavaSimpleDataReaderFactory(5, 10)); - } - } - - static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { - private int start; - private int end; - - JavaSimpleDataReaderFactory(int start, int end) { - this.start = start; - this.end = end; - } + class ReadSupport extends JavaSimpleReadSupport { @Override - public DataReader createDataReader() { - return new JavaSimpleDataReaderFactory(start - 1, end); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } - - @Override - public Row get() { - return new GenericRow(new Object[] {start, -start}); - } - - @Override - public void close() throws IOException { - + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 5); + partitions[1] = new JavaRangeInputPartition(5, 10); + return partitions; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java new file mode 100644 index 0000000000000..685f9b9747e85 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +abstract class JavaSimpleReadSupport implements BatchReadSupport { + + @Override + public StructType fullSchema() { + return new StructType().add("i", "int").add("j", "int"); + } + + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new JavaNoopScanConfigBuilder(fullSchema()); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new JavaSimpleReaderFactory(); + } +} + +class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { + + private StructType schema; + + JavaNoopScanConfigBuilder(StructType schema) { + this.schema = schema; + } + + @Override + public ScanConfig build() { + return this; + } + + @Override + public StructType readSchema() { + return schema; + } +} + +class JavaSimpleReaderFactory implements PartitionReaderFactory { + + @Override + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {current, -current}); + } + + @Override + public void close() throws IOException { + + } + }; + } +} + +class JavaRangeInputPartition implements InputPartition { + int start; + int end; + + JavaRangeInputPartition(int start, int end) { + this.start = start; + this.end = end; + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java deleted file mode 100644 index c3916e0b370b5..0000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.StructType; - -public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader, SupportsScanUnsafeRow { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List> createUnsafeRowReaderFactories() { - return java.util.Arrays.asList( - new JavaUnsafeRowDataReaderFactory(0, 5), - new JavaUnsafeRowDataReaderFactory(5, 10)); - } - } - - static class JavaUnsafeRowDataReaderFactory - implements DataReaderFactory, DataReader { - private int start; - private int end; - private UnsafeRow row; - - JavaUnsafeRowDataReaderFactory(int start, int end) { - this.start = start; - this.end = end; - this.row = new UnsafeRow(2); - row.pointTo(new byte[8 * 3], 8 * 3); - } - - @Override - public DataReader createDataReader() { - return new JavaUnsafeRowDataReaderFactory(start - 1, end); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } - - @Override - public UnsafeRow get() { - row.setInt(0, start); - row.setInt(1, -start); - return row; - } - - @Override - public void close() throws IOException { - - } - } - - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); - } -} diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 46b38bed1c0fb..a36b0cfa6ff18 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider org.apache.spark.sql.streaming.sources.FakeNoWrite -org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback +org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql index d3f928751757c..83c32a5bf2435 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql @@ -13,10 +13,8 @@ DROP VIEW view1; -- Test scenario with Global Temp view CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 as i1; SELECT * FROM global_temp.view1; --- TODO: Support this scenario SELECT global_temp.view1.* FROM global_temp.view1; SELECT i1 FROM global_temp.view1; --- TODO: Support this scenario SELECT global_temp.view1.i1 FROM global_temp.view1; SELECT view1.i1 FROM global_temp.view1; SELECT a.i1 FROM 
global_temp.view1 AS a; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql index 79e90ad3de91d..d001185a73931 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql @@ -14,9 +14,7 @@ SELECT i1 FROM mydb1.t1; SELECT t1.i1 FROM t1; SELECT t1.i1 FROM mydb1.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1; USE mydb2; @@ -24,7 +22,6 @@ SELECT i1 FROM t1; SELECT i1 FROM mydb1.t1; SELECT t1.i1 FROM t1; SELECT t1.i1 FROM mydb1.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1; -- Scenario: resolve fully qualified table name in star expansion @@ -34,7 +31,6 @@ SELECT mydb1.t1.* FROM mydb1.t1; SELECT t1.* FROM mydb1.t1; USE mydb2; SELECT t1.* FROM t1; --- TODO: Support this scenario SELECT mydb1.t1.* FROM mydb1.t1; SELECT t1.* FROM mydb1.t1; SELECT a.* FROM mydb1.t1 AS a; @@ -47,21 +43,17 @@ CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3) SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2); --- TODO: Support this scenario SELECT * FROM mydb1.t3 WHERE c1 IN (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2); -- Scenario: column resolution scenarios in join queries SET spark.sql.crossJoin.enabled = true; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1, mydb2.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1; USE mydb2; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1, mydb1.t1; SET spark.sql.crossJoin.enabled = false; @@ -75,12 +67,10 @@ SELECT t5.t5.i1 FROM mydb1.t5; SELECT t5.i1 FROM mydb1.t5; SELECT t5.* FROM mydb1.t5; SELECT t5.t5.* FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.t5.i1 FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.t5.i2 FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.* FROM mydb1.t5; +SELECT mydb1.t5.* FROM t5; -- Cleanup and Reset USE default; diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 547c2bef02b24..4950a4b7a4e5a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -27,3 +27,36 @@ select current_date = current_date(), current_timestamp = current_timestamp(), a select a, b from ttf2 order by a, current_date; select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); + +select from_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select from_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select from_utc_timestamp(null, 'PST'); + +select from_utc_timestamp('2015-07-24 00:00:00', null); + +select from_utc_timestamp(null, null); + +select from_utc_timestamp(cast(0 as timestamp), 'PST'); + +select from_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select to_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select to_utc_timestamp(null, 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', null); + +select to_utc_timestamp(null, null); + +select to_utc_timestamp(cast(0 as timestamp), 'PST'); + +select to_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +-- SPARK-23715: the input of to/from_utc_timestamp can not have timezone +select 
from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); + +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql new file mode 100644 index 0000000000000..e28f0721a6449 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql @@ -0,0 +1,160 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1); +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1), (2), (2), (3), (5), (5), (null) AS tab2(c1); +CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (2, 2) + AS tab3(k, v); +CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + (1, 2), + (2, 3), + (2, 2), + (2, 2), + (2, 20) + AS tab4(k, v); + +-- Basic EXCEPT ALL +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2; + +-- MINUS ALL (synonym for EXCEPT) +SELECT * FROM tab1 +MINUS ALL +SELECT * FROM tab2; + +-- EXCEPT ALL same table in both branches +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 IS NOT NULL; + +-- Empty left relation +SELECT * FROM tab1 WHERE c1 > 5 +EXCEPT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6; + +-- Type Coerced ExceptAll +SELECT * FROM tab1 +EXCEPT ALL +SELECT CAST(1 AS BIGINT); + +-- Error as types of two side are not compatible +SELECT * FROM tab1 +EXCEPT ALL +SELECT array(1); + +-- Basic +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4; + +-- Basic +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3; + +-- EXCEPT ALL + INTERSECT +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +INTERSECT DISTINCT +SELECT * FROM tab4; + +-- EXCEPT ALL + EXCEPT +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab3 +EXCEPT ALL +SELECT k, v FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Using MINUS ALL +SELECT * FROM tab3 +MINUS ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +MINUS DISTINCT +SELECT * FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +EXCEPT DISTINCT +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Join under except all. Should produce empty resultset since both left and right sets +-- are same. 
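+-- EXCEPT ALL removes rows by multiplicity rather than applying DISTINCT, so two identical
+-- join results cancel out row for row and nothing remains.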
+SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k); + +-- Join under except all (2) +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab4.v AS k, + tab3.k AS v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k); + +-- Group by under ExceptAll +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k; + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; +DROP VIEW IF EXISTS tab3; +DROP VIEW IF EXISTS tab4; diff --git a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql new file mode 100644 index 0000000000000..9adf5d70056e2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql @@ -0,0 +1,21 @@ +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c; + +select extract(year from c) from t; + +select extract(quarter from c) from t; + +select extract(month from c) from t; + +select extract(week from c) from t; + +select extract(day from c) from t; + +select extract(dayofweek from c) from t; + +select extract(hour from c) from t; + +select extract(minute from c) from t; + +select extract(second from c) from t; + +select extract(not_supported from c) from t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql index 3594283505280..6bbde9f38d657 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql @@ -13,5 +13,41 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((a)); -- SPARK-17849: grouping set throws NPE #3 SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((c)); +-- Group sets without explicit group by +SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1); +-- Group sets without group by and with grouping +SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1); + +-- Mutiple grouping within a grouping set +SELECT c1, c2, Sum(c3), grouping__id +FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3) +GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) +HAVING GROUPING__ID > 1; + +-- Group sets without explicit group by +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2); + +-- Mutiple grouping within a grouping set +SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)); + +-- complex expression in grouping sets +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b)); + +-- complex expression in 
grouping sets +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b)); + +-- more query constructs with grouping sets +SELECT c1 AS col1, c2 AS col2 +FROM (VALUES (1, 2), (3, 2)) t(c1, c2) +GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) ) +HAVING col2 IS NOT NULL +ORDER BY -col1; + +-- negative tests - must have at least one grouping expression +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP; + +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE; + +SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()); diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql new file mode 100644 index 0000000000000..02ad5e3538689 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -0,0 +1,85 @@ +create or replace temporary view nested as values + (1, array(32, 97), array(array(12, 99), array(123, 42), array(1))), + (2, array(77, -76), array(array(6, 96, 65), array(-1, -2))), + (3, array(12), array(array(17))) + as t(x, ys, zs); + +-- Only allow lambda's in higher order functions. +select upper(x -> x) as v; + +-- Identity transform an array +select transform(zs, z -> z) as v from nested; + +-- Transform an array +select transform(ys, y -> y * y) as v from nested; + +-- Transform an array with index +select transform(ys, (y, i) -> y + i) as v from nested; + +-- Transform an array with reference +select transform(zs, z -> concat(ys, z)) as v from nested; + +-- Transform an array to an array of 0's +select transform(ys, 0) as v from nested; + +-- Transform a null array +select transform(cast(null as array), x -> x + 1) as v; + +-- Filter. +select filter(ys, y -> y > 30) as v from nested; + +-- Filter a null array +select filter(cast(null as array), y -> true) as v; + +-- Filter nested arrays +select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested; + +-- Aggregate. +select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested; + +-- Aggregate average. 
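+-- The accumulator is a struct holding a running sum and count, and the finish lambda divides
+-- them; for the first row of nested, ys = array(32, 97), this yields 129 / 2 = 64.5.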
+select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested; + +-- Aggregate nested arrays +select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested; + +-- Aggregate a null array +select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v; + +-- Check for element existence +select exists(ys, y -> y > 30) as v from nested; + +-- Check for element existence in a null array +select exists(cast(null as array), y -> y > 30) as v; + +-- Zip with array +select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested; + +-- Zip with array with concat +select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v; + +-- Zip with array coalesce +select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v; + +create or replace temporary view nested as values + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) + as t(x, ys); + +-- Identity Transform Keys in a map +select transform_keys(ys, (k, v) -> k) as v from nested; + +-- Transform Keys in a map by adding constant +select transform_keys(ys, (k, v) -> k + 1) as v from nested; + +-- Transform Keys in a map using values +select transform_keys(ys, (k, v) -> k + v) as v from nested; + +-- Identity Transform values in a map +select transform_values(ys, (k, v) -> v) as v from nested; + +-- Transform values in a map by adding constant +select transform_values(ys, (k, v) -> v + 1) as v from nested; + +-- Transform values in a map using values +select transform_values(ys, (k, v) -> k + v) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql new file mode 100644 index 0000000000000..b0b2244048caa --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql @@ -0,0 +1,160 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v); +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v); + +-- Basic INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- INTERSECT ALL same table in both branches +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1; + +-- Empty left relation +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3; + +-- Type Coerced INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT); + +-- Error as types of two side are not compatible +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2; + +-- Basic +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- Chain of different `set operations +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +; + +-- Chain of different `set operations +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +; + +-- test use parenthesis to control order of evaluation +( + ( + ( + SELECT * FROM tab1 + EXCEPT + SELECT * FROM tab2 + ) + EXCEPT + SELECT * FROM tab1 + ) + INTERSECT ALL + SELECT * 
FROM tab2 +) +; + +-- Join under intersect all +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Join under intersect all (2) +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Group by under intersect all +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k; + +-- Test pre spark2.4 behaviour of set operation precedence +-- All the set operators are given equal precedence and are evaluated +-- from left to right as they appear in the query. + +-- Set the property +SET spark.sql.legacy.setopsPrecedence.enabled= true; + +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT +SELECT * FROM tab2; + +-- Restore the property +SET spark.sql.legacy.setopsPrecedence.enabled = false; + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql index 8afa3270f4de4..2e6a5f362a8fa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index fea069eac4d48..0cf370c13e8c0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -31,3 +31,23 @@ CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; -- Clean up DROP VIEW IF EXISTS jsonTable; + +-- from_json - complex types +select from_json('{"a":1, "b":2}', 'map'); +select from_json('{"a":1, "b":"2"}', 'struct'); + +-- infer schema of json literal +select schema_of_json('{"c1":0, "c2":[1]}'); +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')); + +-- from_json - array type +select from_json('[1, 2, 3]', 'array'); +select from_json('[1, "2", 3]', 'array'); +select from_json('[1, 2, null]', 'array'); + +select from_json('[{"a": 1}, {"a":2}]', 'array>'); +select from_json('{"a": 1}', 'array>'); +select from_json('[null, {"a":2}]', 'array>'); + +select from_json('[{"a": 1}, {"b":2}]', 'array>'); +select from_json('[{"a": 1}, 2]', 'array>'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index f21912a042716..e33cd819f281f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,3 +1,5 @@ +-- Disable 
global limit parallel +set spark.sql.limit.flatGlobalLimit=false; -- limit on various data types SELECT * FROM testdata LIMIT 2; @@ -13,6 +15,11 @@ SELECT * FROM testdata LIMIT CAST(1 AS int); SELECT * FROM testdata LIMIT -1; SELECT * FROM testData TABLESAMPLE (-1 ROWS); + +SELECT * FROM testdata LIMIT CAST(1 AS INT); +-- evaluated limit must not be null +SELECT * FROM testdata LIMIT CAST(NULL AS INT); + -- limit must be foldable SELECT * FROM testdata LIMIT key > 3; diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql index 71a50157b766c..e0abeda3eb44f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + create temporary view nt1 as select * from values ("one", 1), ("two", 2), diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql index cdc6c81e10047..ce09c21568f13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + -- SPARK-17099: Incorrect result when HAVING clause is added to group by query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (-234), (145), (367), (975), (298) diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql new file mode 100644 index 0000000000000..1f607b334dc18 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -0,0 +1,289 @@ +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s); + +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s); + +-- pivot courses +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot years with no subquery +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot courses with multiple aggregations +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column and with multiple aggregations on different columns +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) 
+PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple group by columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +); + +-- pivot on join query with multiple aggregations on different columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple columns in one aggregation +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with aliases and projection +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +); + +-- pivot with projection and value aliases +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +); + +-- pivot years with non-aggregate function +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with one of the expressions as non-aggregate function +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +); + +-- pivot with unresolvable columns +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with complex aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +); + +-- pivot with invalid arguments in aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on multiple pivot columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +); + +-- pivot on multiple pivot columns with aliased values +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +); + +-- pivot on multiple pivot columns with values of wrong data types +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +); + +-- pivot with unresolvable values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +); + +-- pivot with non-literal values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +); + +-- pivot on join query with columns of complex data types +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on multiple pivot columns with agg columns of complex data types +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) 
+); + +-- pivot on pivot column of array type +SELECT * FROM ( + SELECT earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR a IN (array(1, 1), array(2, 2)) +); + +-- pivot on multiple pivot columns containing array type +SELECT * FROM ( + SELECT course, earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2))) +); + +-- pivot on pivot column of struct type +SELECT * FROM ( + SELECT earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN ((1, 'a'), (2, 'b')) +); + +-- pivot on multiple pivot columns containing struct type +SELECT * FROM ( + SELECT course, earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b'))) +); + +-- pivot on pivot column of map type +SELECT * FROM ( + SELECT earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR m IN (map('1', 1), map('2', 2)) +); + +-- pivot on multiple pivot columns containing map type +SELECT * FROM ( + SELECT course, earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) +); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql index cc4ed64affec7..cefc3fe6272ab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql @@ -1,5 +1,9 @@ -- Tests EXISTS subquery support. 
Tests Exists subquery -- used in Joins (Both when joins occurs in outer and suquery blocks) +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES (100, "emp 1", date "2005-01-01", 100.00D, 10), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql new file mode 100644 index 0000000000000..f4ffc20086386 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql @@ -0,0 +1,14 @@ +create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1); +create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2); +create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from + values (1, 1), (1, 2), (2, 1), (2, 2); + +select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b); +-- Invalid query, see SPARK-24341 +select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b); + +-- Aliasing is needed as a workaround for SPARK-24443 +select count(*) from struct_tab where record in + (select (a2 as a, b2 as b) from tab_b); +select count(*) from struct_tab where record not in + (select (a2 as a, b2 as b) from tab_b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index 880175fd7add0..22f3eafd6a02d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for IN JOINS in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql index a40ee082ba3b9..a862e0985b20c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql @@ -1,6 +1,9 @@ -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery -- It includes correlated cases. 
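-- A minimal, self-contained sketch (illustration only, not part of the generated suite; all
-- names and inline data below are made up) of the pattern this file exercises: an IN predicate
-- whose subquery result is truncated by LIMIT. Only the row with a = 1 is returned here.
SELECT a
FROM VALUES (1), (2), (3) AS probe(a)
WHERE a IN (SELECT b FROM VALUES (1), (1) AS sub(b) LIMIT 1);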
+-- Disable global limit optimization +set spark.sql.limit.flatGlobalLimit=false; + create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -97,4 +100,4 @@ WHERE t1d NOT IN (SELECT t2d LIMIT 1) GROUP BY t1b ORDER BY t1b NULLS last -LIMIT 1; \ No newline at end of file +LIMIT 1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index e09b91f18de0a..4f8ca8bfb27c1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for not-in-joins in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql new file mode 100644 index 0000000000000..8eea84f4f5272 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql @@ -0,0 +1,39 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- This file has the same test cases as not-in-unit-tests-multi-column.sql with literals instead of +-- subqueries. Small changes have been made to the literals to make them typecheck. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +-- Case 1 (not possible to write a literal with no rows, so we ignore it.) +-- (subquery is empty -> row is returned) + +-- Cases 2, 3 and 4 are currently broken, so I have commented them out here. +-- Filed https://issues.apache.org/jira/browse/SPARK-24395 to fix and restore these test cases. 
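-- A minimal, self-contained sketch (illustration only, not part of the generated suite; the
-- inline data below is made up) of the multi-column NOT IN-with-literals pattern used in the
-- cases that follow: (4, 5.0) is kept because it matches neither literal tuple, while (2, 3.0)
-- is filtered out.
SELECT *
FROM VALUES (2, 3.0), (4, 5.0) AS probe(a, b)
WHERE (a, b) NOT IN ((2, 3.0), (1, 1.0));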
+ + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql new file mode 100644 index 0000000000000..9f8dc7fca3b94 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql @@ -0,0 +1,98 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- +-- Test cases for multi-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | do filter columns contain null? | a = c? | b = d? | row included in result? | +-- | 1 | empty | * | * | * | yes | +-- | 2 | 1+ row has null for all columns | * | * | * | no | +-- | 3 | no row has null for all columns | (yes, yes) | * | * | no | +-- | 4 | no row has null for all columns | (no, yes) | yes | * | no | +-- | 5 | no row has null for all columns | (no, yes) | no | * | yes | +-- | 6 | no | (no, no) | yes | yes | no | +-- | 7 | no | (no, no) | _ | _ | yes | +-- +-- This can be generalized to include more tests for more columns, but it covers the main cases +-- when there is more than one column. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d); + + -- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +; + + -- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +; + + -- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql new file mode 100644 index 0000000000000..b261363d1dde7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql @@ -0,0 +1,42 @@ +-- Unit tests for simple NOT IN with a literal expression of a single column +-- +-- More information can be found in not-in-unit-tests-single-column.sql. +-- This file has the same test cases as not-in-unit-tests-single-column.sql with literals instead of +-- subqueries. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null); + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2); + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2); + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql new file mode 100644 index 0000000000000..2cc08e10acf67 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql @@ -0,0 +1,123 @@ +-- Unit tests for simple NOT IN predicate subquery across a single column. +-- +-- ``col NOT IN expr'' is quite difficult to reason about. There are many edge cases, some of the +-- rules are confusing to the uninitiated, and precedence and treatment of null values is plain +-- unintuitive. To make this simpler to understand, I've come up with a plain English way of +-- describing the expected behavior of this query. +-- +-- - If the subquery is empty (i.e. returns no rows), the row should be returned, regardless of +-- whether the filtered columns include nulls. +-- - If the subquery contains a result with all columns null, then the row should not be returned. +-- - If for all non-null filter columns there exists a row in the subquery in which each column +-- either +-- 1. is equal to the corresponding filter column or +-- 2. is null +-- then the row should not be returned. (This includes the case where all filter columns are +-- null.) +-- - Otherwise, the row should be returned. +-- +-- Using these rules, we can come up with a set of test cases for single-column and multi-column +-- NOT IN test cases. +-- +-- Test cases for single-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | is a null? | a = c? | row with a included in result? | +-- | 1 | empty | | | yes | +-- | 2 | yes | | | no | +-- | 3 | no | yes | | no | +-- | 4 | no | no | yes | no | +-- | 5 | no | no | no | yes | +-- +-- There are also some considerations around correlated subqueries. 
Correlated subqueries can +-- cause cases 2, 3, or 4 to be reduced to case 1 by limiting the number of rows returned by the +-- subquery, so the row from the parent table should always be included in the output. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +; + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +; + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +; + + -- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql index b15f4da81dd93..95b115a8dd094 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql @@ -13,6 +13,14 @@ CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) AS t3(t3a, t3b, t3c); +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES + (CAST(1 AS DOUBLE), CAST(2 AS STRING), CAST(3 AS STRING)) +AS t1(t4a, t4b, t4c); + +CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES + (CAST(1 AS DECIMAL(18, 0)), CAST(2 AS STRING), CAST(3 AS BIGINT)) +AS t1(t5a, t5b, t5c); + -- TC 01.01 SELECT ( SELECT max(t2b), min(t2b) @@ -44,4 +52,10 @@ WHERE (t1a, t1b) IN (SELECT t2a FROM t2 WHERE t1a = t2a); - +-- TC 01.05 +SELECT * FROM t4 +WHERE +(t4a, t4b, t4c) IN (SELECT t5a, + t5b, + t5c + FROM t5); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql new file mode 100644 index 0000000000000..99729c007b104 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql @@ -0,0 +1,11 @@ +SELECT array_join(array(true, false), ', '); 
+SELECT array_join(array(2Y, 1Y), ', '); +SELECT array_join(array(2S, 1S), ', '); +SELECT array_join(array(2, 1), ', '); +SELECT array_join(array(2L, 1L), ', '); +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', '); +SELECT array_join(array(2.0D, 1.0D), ', '); +SELECT array_join(array(float(2.0), float(1.0)), ', '); +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', '); +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', '); +SELECT array_join(array('a', 'b'), ', '); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index db00a18f2e7e9..99f46dd19d0e2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -148,6 +148,8 @@ SELECT (tinyint_array1 || smallint_array2) ts_array, (smallint_array1 || int_array2) si_array, (int_array1 || bigint_array2) ib_array, + (bigint_array1 || decimal_array2) bd_array, + (decimal_array1 || double_array2) dd_array, (double_array1 || float_array2) df_array, (string_array1 || data_array2) std_array, (timestamp_array1 || string_array2) tst_array, diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index 9be7fcdadfea8..28a0e20c0f495 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -40,12 +40,14 @@ select 10.3000 * 3.0; select 10.30000 * 30.0; select 10.300000000000000000 * 3.000000000000000000; select 10.300000000000000000 * 3.0000000000000000000; +select 2.35E10 * 1.0; -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; +select 1.2345678901234567890E30 * 1.2345678901234567890E25; -- arithmetic operations causing a precision loss are truncated select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; @@ -67,12 +69,14 @@ select 10.3000 * 3.0; select 10.30000 * 30.0; select 10.300000000000000000 * 3.000000000000000000; select 10.300000000000000000 * 3.0000000000000000000; +select 2.35E10 * 1.0; -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; +select 1.2345678901234567890E30 * 1.2345678901234567890E25; -- arithmetic operations causing a precision loss return NULL select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql new file mode 100644 index 0000000000000..1727ee725db2e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql @@ -0,0 +1,78 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + 
map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), + map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +); + +SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map2, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql new file mode 100644 index 0000000000000..69da67fc66fc0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -0,0 +1,95 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), 
struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +); + +-- Concatenate maps of the same type +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps; + +-- Concatenate maps of different types +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(bigint_map1, decimal_map2) bd_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps; + +-- Concatenate map of incompatible types 1 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps; + +-- Concatenate map of incompatible types 2 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps; + +-- Concatenate map of incompatible types 3 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps; + +-- Concatenate map of incompatible types 4 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps; + +-- Concatenate map of incompatible types 5 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql new file mode 100644 index 0000000000000..92c7e26e3add2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql @@ -0,0 +1,56 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x); + +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px; + + +select id, regr_count(y,x) over (partition by px) from t1 order by id; diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index e57d69eaad033..6da1b9b49b226 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -35,6 +35,17 @@ FROM (SELECT col AS col SELECT col FROM p3) T1) T2; +-- SPARK-24012 Union of map and other compatible columns. +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1; + +-- SPARK-24012 Union of array and other compatible columns. 
+SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1; + + -- Clean-up DROP VIEW IF EXISTS t1; DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index 539f673c9d679..9fc97f0c39149 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -72,7 +72,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 9 @@ -81,7 +81,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 10 @@ -90,7 +90,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +Reference 'mydb1.t1.i1' is ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 11 @@ -99,7 +99,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb1.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 12 @@ -108,7 +108,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb1.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 13 @@ -125,7 +125,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb2.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 15 @@ -134,7 +134,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb2.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 16 @@ -143,7 +143,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb2.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 17 @@ -152,7 +152,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb2.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 18 @@ -161,7 +161,7 @@ SELECT db1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`db1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +cannot resolve '`db1.t1.i1`' given input columns: [mydb2.t1.i1, mydb2.t1.i1]; line 1 pos 7 -- !query 19 @@ -186,7 +186,7 @@ SELECT mydb1.t1 FROM t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1`' given input columns: [t1.i1]; 
line 1 pos 7 +cannot resolve '`mydb1.t1`' given input columns: [mydb1.t1.i1]; line 1 pos 7 -- !query 22 @@ -204,7 +204,7 @@ SELECT t1 FROM mydb1.t1 struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '`t1`' given input columns: [t1.i1]; line 1 pos 7 +cannot resolve '`t1`' given input columns: [mydb1.t1.i1]; line 1 pos 7 -- !query 24 @@ -221,7 +221,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [mydb2.t1.i1]; line 1 pos 7 -- !query 26 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 2092119600954..3d8fb661afe55 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -85,10 +85,9 @@ struct -- !query 10 SELECT global_temp.view1.* FROM global_temp.view1 -- !query 10 schema -struct<> +struct -- !query 10 output -org.apache.spark.sql.AnalysisException -cannot resolve 'global_temp.view1.*' given input columns 'i1'; +1 -- !query 11 @@ -102,10 +101,9 @@ struct -- !query 12 SELECT global_temp.view1.i1 FROM global_temp.view1 -- !query 12 schema -struct<> +struct -- !query 12 output -org.apache.spark.sql.AnalysisException -cannot resolve '`global_temp.view1.i1`' given input columns: [view1.i1]; line 1 pos 7 +1 -- !query 13 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index e10f516ad6e5b..73e3fdc08232c 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query 0 @@ -93,19 +93,17 @@ struct -- !query 11 SELECT mydb1.t1.i1 FROM t1 -- !query 11 schema -struct<> +struct -- !query 11 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 12 SELECT mydb1.t1.i1 FROM mydb1.t1 -- !query 12 schema -struct<> +struct -- !query 12 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 13 @@ -151,10 +149,9 @@ struct -- !query 18 SELECT mydb1.t1.i1 FROM mydb1.t1 -- !query 18 schema -struct<> +struct -- !query 18 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 19 @@ -176,10 +173,9 @@ struct -- !query 21 SELECT mydb1.t1.* FROM mydb1.t1 -- !query 21 schema -struct<> +struct -- !query 21 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' given input columns 'i1'; +1 -- !query 22 @@ -209,10 +205,9 @@ struct -- !query 25 SELECT mydb1.t1.* FROM mydb1.t1 -- !query 25 schema -struct<> +struct -- !query 25 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' given input columns 'i1'; +1 -- !query 26 @@ -267,10 +262,9 @@ struct SELECT * FROM mydb1.t3 WHERE c1 IN (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2) -- !query 32 schema -struct<> +struct -- !query 32 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t4.c3`' 
given input columns: [t4.c2, t4.c3]; line 2 pos 42 +4 1 -- !query 33 @@ -284,19 +278,17 @@ spark.sql.crossJoin.enabled true -- !query 34 SELECT mydb1.t1.i1 FROM t1, mydb2.t1 -- !query 34 schema -struct<> +struct -- !query 34 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 35 SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 -- !query 35 schema -struct<> +struct -- !query 35 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 36 @@ -310,10 +302,9 @@ struct<> -- !query 37 SELECT mydb1.t1.i1 FROM t1, mydb1.t1 -- !query 37 schema -struct<> +struct -- !query 37 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 38 @@ -399,40 +390,37 @@ struct -- !query 48 SELECT mydb1.t5.t5.i1 FROM mydb1.t5 -- !query 48 schema -struct<> +struct -- !query 48 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i1`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 +2 -- !query 49 SELECT mydb1.t5.t5.i2 FROM mydb1.t5 -- !query 49 schema -struct<> +struct -- !query 49 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i2`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 +3 -- !query 50 SELECT mydb1.t5.* FROM mydb1.t5 -- !query 50 schema -struct<> +struct> -- !query 50 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t5.*' given input columns 'i1, t5'; +1 {"i1":2,"i2":3} -- !query 51 -USE default +SELECT mydb1.t5.* FROM t5 -- !query 51 schema -struct<> +struct> -- !query 51 output - +1 {"i1":2,"i2":3} -- !query 52 -DROP DATABASE mydb1 CASCADE +USE default -- !query 52 schema struct<> -- !query 52 output @@ -440,8 +428,16 @@ struct<> -- !query 53 -DROP DATABASE mydb2 CASCADE +DROP DATABASE mydb1 CASCADE -- !query 53 schema struct<> -- !query 53 output + + +-- !query 54 +DROP DATABASE mydb2 CASCADE +-- !query 54 schema +struct<> +-- !query 54 output + diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 4e1cfa6e48c1c..9eede305dbdcc 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 26 -- !query 0 @@ -82,9 +82,138 @@ struct 1 2 2 3 + -- !query 9 select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') --- !query 3 schema +-- !query 9 schema struct --- !query 3 output +-- !query 9 output 5 3 5 NULL 4 + + +-- !query 10 +select from_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 10 schema +struct +-- !query 10 output +2015-07-23 17:00:00 + + +-- !query 11 +select from_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 11 schema +struct +-- !query 11 output +2015-01-23 16:00:00 + + +-- !query 12 +select from_utc_timestamp(null, 'PST') +-- !query 12 schema +struct +-- !query 12 output +NULL + + +-- !query 13 +select from_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 13 schema +struct +-- !query 13 output +NULL + + +-- !query 14 +select from_utc_timestamp(null, null) +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select from_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 15 
schema +struct +-- !query 15 output +1969-12-31 08:00:00 + + +-- !query 16 +select from_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 16 schema +struct +-- !query 16 output +2015-01-23 16:00:00 + + +-- !query 17 +select to_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 17 schema +struct +-- !query 17 output +2015-07-24 07:00:00 + + +-- !query 18 +select to_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 18 schema +struct +-- !query 18 output +2015-01-24 08:00:00 + + +-- !query 19 +select to_utc_timestamp(null, 'PST') +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select to_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select to_utc_timestamp(null, null) +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select to_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 22 schema +struct +-- !query 22 output +1970-01-01 00:00:00 + + +-- !query 23 +select to_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 23 schema +struct +-- !query 23 output +2015-01-24 08:00:00 + + +-- !query 24 +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 25 schema +struct +-- !query 25 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 51dac111029e8..8ba69c698b551 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -57,6 +57,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -89,7 +91,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -122,7 +126,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -147,7 +153,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -180,7 +188,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in 
comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -205,7 +215,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -230,7 +242,9 @@ Database default Table t Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 -Partition Statistics 1054 bytes, 2 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1144 bytes, 2 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 8c908b7625056..79390cb424444 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -282,6 +282,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 @@ -311,6 +313,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out new file mode 100644 index 0000000000000..01091a2f751ce --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out @@ -0,0 +1,346 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 27 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1), (2), (2), (3), (5), (5), (null) AS tab2(c1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (2, 2) + AS tab3(k, v) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + (1, 2), + (2, 3), + (2, 2), + (2, 2), + (2, 20) + AS tab4(k, v) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct +-- !query 4 output +0 +2 +2 +NULL + + +-- !query 5 +SELECT * FROM tab1 +MINUS ALL +SELECT * FROM tab2 +-- !query 5 schema +struct +-- !query 5 output +0 +2 +2 +NULL + + +-- !query 6 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 IS NOT NULL +-- !query 6 schema +struct +-- !query 6 output +0 +2 +2 
+NULL +NULL + + +-- !query 7 +SELECT * FROM tab1 WHERE c1 > 5 +EXCEPT ALL +SELECT * FROM tab2 +-- !query 7 schema +struct +-- !query 7 output + + + +-- !query 8 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6 +-- !query 8 schema +struct +-- !query 8 output +0 +1 +2 +2 +2 +2 +3 +NULL +NULL + + +-- !query 9 +SELECT * FROM tab1 +EXCEPT ALL +SELECT CAST(1 AS BIGINT) +-- !query 9 schema +struct +-- !query 9 output +0 +2 +2 +2 +2 +3 +NULL +NULL + + +-- !query 10 +SELECT * FROM tab1 +EXCEPT ALL +SELECT array(1) +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +ExceptAll can only be performed on tables with the compatible column types. array <> int at the first column of the second table; + + +-- !query 11 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +-- !query 11 schema +struct +-- !query 11 output +1 2 +1 3 + + +-- !query 12 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +-- !query 12 schema +struct +-- !query 12 output +2 2 +2 20 + + +-- !query 13 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +INTERSECT DISTINCT +SELECT * FROM tab4 +-- !query 13 schema +struct +-- !query 13 output +2 2 +2 20 + + +-- !query 14 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 14 schema +struct +-- !query 14 output + + + +-- !query 15 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 15 schema +struct +-- !query 15 output +1 3 + + +-- !query 16 +SELECT k FROM tab3 +EXCEPT ALL +SELECT k, v FROM tab4 +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +ExceptAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 17 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 17 schema +struct +-- !query 17 output +1 3 + + +-- !query 18 +SELECT * FROM tab3 +MINUS ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +MINUS DISTINCT +SELECT * FROM tab4 +-- !query 18 schema +struct +-- !query 18 output +1 3 + + +-- !query 19 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +EXCEPT DISTINCT +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 19 schema +struct +-- !query 19 output + + + +-- !query 20 +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +-- !query 20 schema +struct +-- !query 20 output + + + +-- !query 21 +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab4.v AS k, + tab3.k AS v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +-- !query 21 schema +struct +-- !query 21 output +1 2 +1 2 +1 2 +2 20 +2 20 +2 3 +2 3 + + +-- !query 22 +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k +-- !query 22 schema +struct +-- !query 22 output +3 + + +-- !query 23 +DROP VIEW IF EXISTS tab1 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +DROP VIEW IF EXISTS tab2 +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +DROP VIEW IF EXISTS tab3 +-- !query 25 schema +struct<> +-- !query 25 output + + + +-- !query 26 +DROP VIEW IF EXISTS tab4 +-- !query 26 schema +struct<> +-- !query 26 output + diff --git 
a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out new file mode 100644 index 0000000000000..160e4c7d78455 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -0,0 +1,96 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select extract(year from c) from t +-- !query 1 schema +struct +-- !query 1 output +2011 + + +-- !query 2 +select extract(quarter from c) from t +-- !query 2 schema +struct +-- !query 2 output +2 + + +-- !query 3 +select extract(month from c) from t +-- !query 3 schema +struct +-- !query 3 output +5 + + +-- !query 4 +select extract(week from c) from t +-- !query 4 schema +struct +-- !query 4 output +18 + + +-- !query 5 +select extract(day from c) from t +-- !query 5 schema +struct +-- !query 5 output +6 + + +-- !query 6 +select extract(dayofweek from c) from t +-- !query 6 schema +struct +-- !query 6 output +6 + + +-- !query 7 +select extract(hour from c) from t +-- !query 7 schema +struct +-- !query 7 output +7 + + +-- !query 8 +select extract(minute from c) from t +-- !query 8 schema +struct +-- !query 8 output +8 + + +-- !query 9 +select extract(second from c) from t +-- !query 9 schema +struct +-- !query 9 output +9 + + +-- !query 10 +select extract(not_supported from c) from t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +Literals of type 'NOT_SUPPORTED' are currently not supported.(line 1, pos 7) + +== SQL == +select extract(not_supported from c) from t +-------^^^ diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out index edb38a52b7514..34ab09c5e3bba 100644 --- a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 15 -- !query 0 @@ -40,3 +40,127 @@ struct NULL NULL 3 1 NULL NULL 6 1 NULL NULL 9 1 + + +-- !query 4 +SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) +-- !query 4 schema +struct +-- !query 4 output +x 10 +y 20 + + +-- !query 5 +SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) +-- !query 5 schema +struct +-- !query 5 output +x 10 0 +y 20 0 + + +-- !query 6 +SELECT c1, c2, Sum(c3), grouping__id +FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3) +GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) 
+HAVING GROUPING__ID > 1 +-- !query 6 schema +struct +-- !query 6 output +NULL a 10 2 +NULL b 20 2 + + +-- !query 7 +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2) +-- !query 7 schema +struct +-- !query 7 output +0 +0 +1 +1 + + +-- !query 8 +SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)) +-- !query 8 schema +struct +-- !query 8 output +-1 +-1 +-3 +-3 + + +-- !query 9 +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b)) +-- !query 9 schema +struct<(a + b):int,b:int,sum(c):bigint> +-- !query 9 output +2 NULL 1 +4 NULL 2 +NULL 1 1 +NULL 2 2 + + +-- !query 10 +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b)) +-- !query 10 schema +struct<(a + b):int,b:int,sum(c):bigint> +-- !query 10 output +2 NULL 2 +4 NULL 4 +NULL 1 1 +NULL 2 2 + + +-- !query 11 +SELECT c1 AS col1, c2 AS col2 +FROM (VALUES (1, 2), (3, 2)) t(c1, c2) +GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) ) +HAVING col2 IS NOT NULL +ORDER BY -col1 +-- !query 11 schema +struct +-- !query 11 output +3 2 +1 2 + + +-- !query 12 +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.catalyst.parser.ParseException + +extraneous input 'ROLLUP' expecting (line 1, pos 53) + +== SQL == +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP +-----------------------------------------------------^^^ + + +-- !query 13 +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +extraneous input 'CUBE' expecting (line 1, pos 53) + +== SQL == +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE +-----------------------------------------------------^^^ + + +-- !query 14 +SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +expression '`c1`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out new file mode 100644 index 0000000000000..32d20d1b73415 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -0,0 +1,255 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 27 + + +-- !query 0 +create or replace temporary view nested as values + (1, array(32, 97), array(array(12, 99), array(123, 42), array(1))), + (2, array(77, -76), array(array(6, 96, 65), array(-1, -2))), + (3, array(12), array(array(17))) + as t(x, ys, zs) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select upper(x -> x) as v +-- !query 1 schema +struct<> +-- !query 1 output +org.apache.spark.sql.AnalysisException +A lambda function should only be used in a higher order function. 
However, its class is org.apache.spark.sql.catalyst.expressions.Upper, which is not a higher order function.; line 1 pos 7 + + +-- !query 2 +select transform(zs, z -> z) as v from nested +-- !query 2 schema +struct>> +-- !query 2 output +[[12,99],[123,42],[1]] +[[17]] +[[6,96,65],[-1,-2]] + + +-- !query 3 +select transform(ys, y -> y * y) as v from nested +-- !query 3 schema +struct> +-- !query 3 output +[1024,9409] +[144] +[5929,5776] + + +-- !query 4 +select transform(ys, (y, i) -> y + i) as v from nested +-- !query 4 schema +struct> +-- !query 4 output +[12] +[32,98] +[77,-75] + + +-- !query 5 +select transform(zs, z -> concat(ys, z)) as v from nested +-- !query 5 schema +struct>> +-- !query 5 output +[[12,17]] +[[32,97,12,99],[32,97,123,42],[32,97,1]] +[[77,-76,6,96,65],[77,-76,-1,-2]] + + +-- !query 6 +select transform(ys, 0) as v from nested +-- !query 6 schema +struct> +-- !query 6 output +[0,0] +[0,0] +[0] + + +-- !query 7 +select transform(cast(null as array), x -> x + 1) as v +-- !query 7 schema +struct> +-- !query 7 output +NULL + + +-- !query 8 +select filter(ys, y -> y > 30) as v from nested +-- !query 8 schema +struct> +-- !query 8 output +[32,97] +[77] +[] + + +-- !query 9 +select filter(cast(null as array), y -> true) as v +-- !query 9 schema +struct> +-- !query 9 output +NULL + + +-- !query 10 +select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested +-- !query 10 schema +struct>> +-- !query 10 output +[[96,65],[]] +[[99],[123],[]] +[[]] + + +-- !query 11 +select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested +-- !query 11 schema +struct +-- !query 11 output +131 +15 +5 + + +-- !query 12 +select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested +-- !query 12 schema +struct +-- !query 12 output +0.5 +12.0 +64.5 + + +-- !query 13 +select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested +-- !query 13 schema +struct> +-- !query 13 output +[1010880,8] +[17] +[4752,20664,1] + + +-- !query 14 +select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select exists(ys, y -> y > 30) as v from nested +-- !query 15 schema +struct +-- !query 15 output +false +true +true + + +-- !query 16 +select exists(cast(null as array), y -> y > 30) as v +-- !query 16 schema +struct +-- !query 16 output +NULL + + +-- !query 17 +select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested +-- !query 17 schema +struct> +-- !query 17 output +[13] +[34,99,null] +[80,-74] + + +-- !query 18 +select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v +-- !query 18 schema +struct> +-- !query 18 output +["ad","be","cf"] + + +-- !query 19 +select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v +-- !query 19 schema +struct> +-- !query 19 output +["a",null,"f"] + + +-- !query 20 +create or replace temporary view nested as values + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) + as t(x, ys) +-- !query 20 schema +struct<> +-- !query 20 output + + +-- !query 21 +select transform_keys(ys, (k, v) -> k) as v from nested +-- !query 21 schema +struct> +-- !query 21 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 22 +select transform_keys(ys, (k, v) -> k + 1) as v from nested +-- !query 22 schema +struct> +-- !query 22 output +{2:1,3:2,4:3} +{5:4,6:5,7:6} + + +-- !query 23 +select transform_keys(ys, (k, v) -> k + v) as v from 
nested +-- !query 23 schema +struct> +-- !query 23 output +{10:5,12:6,8:4} +{2:1,4:2,6:3} + + +-- !query 24 +select transform_values(ys, (k, v) -> v) as v from nested +-- !query 24 schema +struct> +-- !query 24 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 25 +select transform_values(ys, (k, v) -> v + 1) as v from nested +-- !query 25 schema +struct> +-- !query 25 output +{1:2,2:3,3:4} +{4:5,5:6,6:7} + + +-- !query 26 +select transform_values(ys, (k, v) -> k + v) as v from nested +-- !query 26 schema +struct> +-- !query 26 output +{1:2,2:4,3:6} +{4:8,5:10,6:12} diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out new file mode 100644 index 0000000000000..63dd56ce468bc --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out @@ -0,0 +1,307 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 22 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 2 schema +struct +-- !query 2 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 3 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1 +-- !query 3 schema +struct +-- !query 3 output +1 2 +1 2 +1 3 +1 3 + + +-- !query 4 +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3 +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT) +-- !query 6 schema +struct +-- !query 6 output +1 2 + + +-- !query 7 +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the compatible column types. 
array <> int at the first column of the second table; + + +-- !query 8 +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 9 +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 9 schema +struct +-- !query 9 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 10 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 10 schema +struct +-- !query 10 output +1 2 +1 2 +1 3 +2 3 +NULL NULL +NULL NULL + + +-- !query 11 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 11 schema +struct +-- !query 11 output +1 3 + + +-- !query 12 +( + ( + ( + SELECT * FROM tab1 + EXCEPT + SELECT * FROM tab2 + ) + EXCEPT + SELECT * FROM tab1 + ) + INTERSECT ALL + SELECT * FROM tab2 +) +-- !query 12 schema +struct +-- !query 12 output + + + +-- !query 13 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 13 schema +struct +-- !query 13 output +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +2 3 + + +-- !query 14 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 14 schema +struct +-- !query 14 output + + + +-- !query 15 +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k +-- !query 15 schema +struct +-- !query 15 output +2 +3 +NULL + + +-- !query 16 +SET spark.sql.legacy.setopsPrecedence.enabled= true +-- !query 16 schema +struct +-- !query 16 output +spark.sql.legacy.setopsPrecedence.enabled true + + +-- !query 17 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 17 schema +struct +-- !query 17 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 18 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT +SELECT * FROM tab2 +-- !query 18 schema +struct +-- !query 18 output +1 2 +2 3 +NULL NULL + + +-- !query 19 +SET spark.sql.legacy.setopsPrecedence.enabled = false +-- !query 19 schema +struct +-- !query 19 output +spark.sql.legacy.setopsPrecedence.enabled false + + +-- !query 20 +DROP VIEW IF EXISTS tab1 +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +DROP VIEW IF EXISTS tab2 +-- !query 21 schema +struct<> +-- !query 21 output + diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 581dddc89d0bb..7444cdbef96e4 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 38 -- !query 0 @@ -24,7 +24,7 @@ Extended Usage: {"a":1,"b":2} > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} - > SELECT 
to_json(array(named_struct('a', 1, 'b', 2)); + > SELECT to_json(array(named_struct('a', 1, 'b', 2))); [{"a":1,"b":2}] > SELECT to_json(map('a', named_struct('b', 1))); {"a":{"b":1}} @@ -120,7 +120,7 @@ select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 12 @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1 and 2; Found: 0; line 1 pos 7 -- !query 13 @@ -183,7 +183,7 @@ select from_json('{"a":1}', 1) struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Expected a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7 -- !query 18 @@ -216,7 +216,7 @@ select from_json('{"a":1}', 'a INT', map('mode', 1)) struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 21 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 +Invalid number of arguments for function from_json. 
Expected: one of 2 and 3; Found: 0; line 1 pos 7 -- !query 22 @@ -258,3 +258,99 @@ DROP VIEW IF EXISTS jsonTable struct<> -- !query 25 output + + +-- !query 26 +select from_json('{"a":1, "b":2}', 'map') +-- !query 26 schema +struct> +-- !query 26 output +{"a":1,"b":2} + + +-- !query 27 +select from_json('{"a":1, "b":"2"}', 'struct') +-- !query 27 schema +struct> +-- !query 27 output +{"a":1,"b":"2"} + + +-- !query 28 +select schema_of_json('{"c1":0, "c2":[1]}') +-- !query 28 schema +struct +-- !query 28 output +struct> + + +-- !query 29 +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) +-- !query 29 schema +struct>> +-- !query 29 output +{"c1":[1,2,3]} + + +-- !query 30 +select from_json('[1, 2, 3]', 'array') +-- !query 30 schema +struct> +-- !query 30 output +[1,2,3] + + +-- !query 31 +select from_json('[1, "2", 3]', 'array') +-- !query 31 schema +struct> +-- !query 31 output +NULL + + +-- !query 32 +select from_json('[1, 2, null]', 'array') +-- !query 32 schema +struct> +-- !query 32 output +[1,2,null] + + +-- !query 33 +select from_json('[{"a": 1}, {"a":2}]', 'array>') +-- !query 33 schema +struct>> +-- !query 33 output +[{"a":1},{"a":2}] + + +-- !query 34 +select from_json('{"a": 1}', 'array>') +-- !query 34 schema +struct>> +-- !query 34 output +[{"a":1}] + + +-- !query 35 +select from_json('[null, {"a":2}]', 'array>') +-- !query 35 schema +struct>> +-- !query 35 output +[null,{"a":2}] + + +-- !query 36 +select from_json('[{"a": 1}, {"b":2}]', 'array>') +-- !query 36 schema +struct>> +-- !query 36 output +[{"a":1},{"b":2}] + + +-- !query 37 +select from_json('[{"a": 1}, 2]', 'array>') +-- !query 37 schema +struct>> +-- !query 37 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 146abe6cbd058..187f3bd6858fe 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,109 +1,134 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 15 -- !query 0 -SELECT * FROM testdata LIMIT 2 +set spark.sql.limit.flatGlobalLimit=false -- !query 0 schema -struct +struct -- !query 0 output +spark.sql.limit.flatGlobalLimit false + + +-- !query 1 +SELECT * FROM testdata LIMIT 2 +-- !query 1 schema +struct +-- !query 1 output 1 1 2 2 --- !query 1 +-- !query 2 SELECT * FROM arraydata LIMIT 2 --- !query 1 schema +-- !query 2 schema struct,nestedarraycol:array>> --- !query 1 output +-- !query 2 output [1,2,3] [[1,2,3]] [2,3,4] [[2,3,4]] --- !query 2 +-- !query 3 SELECT * FROM mapdata LIMIT 2 --- !query 2 schema +-- !query 3 schema struct> --- !query 2 output +-- !query 3 output {1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} {1:"a2",2:"b2",3:"c2",4:"d2"} --- !query 3 +-- !query 4 SELECT * FROM testdata LIMIT 2 + 1 --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output 1 1 2 2 3 3 --- !query 4 +-- !query 5 SELECT * FROM testdata LIMIT CAST(1 AS int) --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output 1 1 --- !query 5 +-- !query 6 SELECT * FROM testdata LIMIT -1 --- !query 5 schema +-- !query 6 schema struct<> --- !query 5 output +-- !query 6 output org.apache.spark.sql.AnalysisException The limit expression must be equal to or greater than 0, but got -1; --- !query 6 +-- !query 7 SELECT * FROM testData TABLESAMPLE (-1 ROWS) --- !query 6 schema +-- !query 7 schema struct<> --- !query 6 output +-- !query 7 
output org.apache.spark.sql.AnalysisException The limit expression must be equal to or greater than 0, but got -1; --- !query 7 +-- !query 8 +SELECT * FROM testdata LIMIT CAST(1 AS INT) +-- !query 8 schema +struct +-- !query 8 output +1 1 + + +-- !query 9 +SELECT * FROM testdata LIMIT CAST(NULL AS INT) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +The evaluated limit expression must not be null, but got CAST(NULL AS INT); + + +-- !query 10 SELECT * FROM testdata LIMIT key > 3 --- !query 7 schema +-- !query 10 schema struct<> --- !query 7 output +-- !query 10 output org.apache.spark.sql.AnalysisException The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); --- !query 8 +-- !query 11 SELECT * FROM testdata LIMIT true --- !query 8 schema +-- !query 11 schema struct<> --- !query 8 output +-- !query 11 output org.apache.spark.sql.AnalysisException The limit expression must be integer type, but got boolean; --- !query 9 +-- !query 12 SELECT * FROM testdata LIMIT 'a' --- !query 9 schema +-- !query 12 schema struct<> --- !query 9 output +-- !query 12 output org.apache.spark.sql.AnalysisException The limit expression must be integer type, but got string; --- !query 10 +-- !query 13 SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 --- !query 10 schema +-- !query 13 schema struct --- !query 10 output +-- !query 13 output 4 --- !query 11 +-- !query 14 SELECT * FROM testdata WHERE key < 3 LIMIT ALL --- !query 11 schema +-- !query 14 schema struct --- !query 11 output +-- !query 14 output 1 1 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index b8c91dc8b59a4..7f301614523b2 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -147,7 +147,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890 @@ -159,7 +159,7 @@ struct<> -- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890.0 @@ -379,7 +379,7 @@ struct<> -- !query 39 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38(line 1, pos 7) +decimal can only support precision up to 38(line 1, pos 7) == SQL == select 1.20E-38BD diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out new file mode 100644 index 0000000000000..2dd92930f92aa --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -0,0 +1,478 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 31 + + +-- !query 0 +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +create temporary view 
yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 3 schema +struct +-- !query 3 output +2012 15000 20000 +2013 48000 30000 + + +-- !query 4 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 4 schema +struct +-- !query 4 output +Java 20000 30000 +dotNET 15000 48000 + + +-- !query 5 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 5 schema +struct +-- !query 5 output +2012 15000 7500.0 20000 20000.0 +2013 48000 48000.0 30000 30000.0 + + +-- !query 6 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 6 schema +struct +-- !query 6 output +63000 50000 + + +-- !query 7 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +) +-- !query 7 schema +struct +-- !query 7 output +63000 2012 50000 2012 + + +-- !query 8 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +) +-- !query 8 schema +struct +-- !query 8 output +Java 2012 20000 NULL +Java 2013 NULL 30000 +dotNET 2012 15000 NULL +dotNET 2013 NULL 48000 + + +-- !query 9 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +) +-- !query 9 schema +struct +-- !query 9 output +2012 15000 1 20000 1 +2013 48000 2 30000 2 + + +-- !query 10 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +) +-- !query 10 schema +struct +-- !query 10 output +2012 15000 20000 +2013 96000 60000 + + +-- !query 11 +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +) +-- !query 11 schema +struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> +-- !query 11 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 12 +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +) +-- !query 12 schema +struct +-- !query 12 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 13 +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +) +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; + + +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, but 
'__auto_generated_subquery_name.`year`' did not appear in any aggregate function.; + + +-- !query 15 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 + + +-- !query 16 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +) +-- !query 16 schema +struct +-- !query 16 output +2012 15000 7501.0 20000 20001.0 +2013 48000 48001.0 30000 30001.0 + + +-- !query 17 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +) +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.; + + +-- !query 18 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +) +-- !query 18 schema +struct +-- !query 18 output +1 15000 NULL +2 NULL 30000 + + +-- !query 19 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +) +-- !query 19 schema +struct +-- !query 19 output +2012 NULL 20000 +2013 48000 NULL + + +-- !query 20 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +Invalid pivot value 'dotNET': value data type string does not match pivot column data type struct; + + +-- !query 21 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +) +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`s`' given input columns: [coursesales.course, coursesales.year, coursesales.earnings]; line 4 pos 15 + + +-- !query 22 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +) +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +Literal expressions required for pivot values, found 'course#x'; + + +-- !query 23 +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +) +-- !query 23 schema +struct,Java:array> +-- !query 23 output +2012 [1,1] [1,1] +2013 [2,2] [2,2] + + +-- !query 24 +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +) +-- !query 24 schema +struct,[2013, Java]:array> +-- !query 24 output +2012 [1,1] NULL +2013 NULL [2,2] + + +-- !query 25 +SELECT * FROM ( + SELECT earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR a IN (array(1, 1), array(2, 2)) +) +-- !query 25 schema +struct +-- !query 25 output +2012 35000 NULL +2013 NULL 78000 + + +-- !query 26 +SELECT 
* FROM ( + SELECT course, earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2))) +) +-- !query 26 schema +struct +-- !query 26 output +2012 15000 NULL +2013 NULL 30000 + + +-- !query 27 +SELECT * FROM ( + SELECT earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN ((1, 'a'), (2, 'b')) +) +-- !query 27 schema +struct +-- !query 27 output +2012 35000 NULL +2013 NULL 78000 + + +-- !query 28 +SELECT * FROM ( + SELECT course, earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b'))) +) +-- !query 28 schema +struct +-- !query 28 output +2012 15000 NULL +2013 NULL 30000 + + +-- !query 29 +SELECT * FROM ( + SELECT earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR m IN (map('1', 1), map('2', 2)) +) +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +Invalid pivot column 'm#x'. Pivot columns must be comparable.; + + +-- !query 30 +SELECT * FROM ( + SELECT course, earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) +) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable.; diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 975bb06124744..abeb7e18f031e 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -178,6 +178,8 @@ struct -- !query 14 output showdb show_t1 false Partition Values: [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1 +Created Time [not included in comparison] +Last Access [not included in comparison] -- !query 15 diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index d5f8705a35ed6..7b3dc84388889 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -36,14 +36,14 @@ struct -- !query 3 output == Parsed Logical Plan == 'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] -+- 'SubqueryAlias __auto_generated_subquery_name ++- 'SubqueryAlias `__auto_generated_subquery_name` +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] +- 'UnresolvedTableValuedFunction range, [10] == Analyzed Logical Plan == col: string Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] -+- SubqueryAlias __auto_generated_subquery_name ++- SubqueryAlias `__auto_generated_subquery_name` +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] +- Range (0, 10, step=1, splits=None) diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out 
b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out new file mode 100644 index 0000000000000..088db55d66406 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out @@ -0,0 +1,70 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from + values (1, 1), (1, 2), (2, 1), (2, 2) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b) +-- !query 3 schema +struct<1:int> +-- !query 3 output + + + +-- !query 4 +select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Cannot analyze (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())). +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2 +#columns in right hand side: 1 +Left side columns: +[tab_a.`a1`, tab_a.`b1`] +Right side columns: +[`named_struct(a2, a2, b2, b2)`]; + + +-- !query 5 +select count(*) from struct_tab where record in + (select (a2 as a, b2 as b) from tab_b) +-- !query 5 schema +struct +-- !query 5 output +1 + + +-- !query 6 +select count(*) from struct_tab where record not in + (select (a2 as a, b2 as b) from tab_b) +-- !query 6 schema +struct +-- !query 6 output +3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index 71ca1f8649475..9eb5b3383e734 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -1,8 +1,16 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 9 -- !query 0 +set spark.sql.limit.flatGlobalLimit=false +-- !query 0 schema +struct +-- !query 0 output +spark.sql.limit.flatGlobalLimit false + + +-- !query 1 create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -17,13 +25,13 @@ create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) --- !query 0 schema +-- !query 1 schema struct<> --- !query 0 output +-- !query 1 output --- !query 1 +-- !query 2 create temporary view t2 as select * from values ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -39,13 +47,13 @@ create temporary view t2 as select * from values 
("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) --- !query 1 schema +-- !query 2 schema struct<> --- !query 1 output +-- !query 2 output --- !query 2 +-- !query 3 create temporary view t3 as select * from values ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), @@ -60,27 +68,27 @@ create temporary view t3 as select * from values ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) --- !query 2 schema +-- !query 3 schema struct<> --- !query 2 output +-- !query 3 output --- !query 3 +-- !query 4 SELECT * FROM t1 WHERE t1a IN (SELECT t2a FROM t2 WHERE t1d = t2d) LIMIT 2 --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 4 +-- !query 5 SELECT * FROM t1 WHERE t1c IN (SELECT t2c @@ -88,16 +96,16 @@ WHERE t1c IN (SELECT t2c WHERE t2b >= 8 LIMIT 2) LIMIT 4 --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 5 +-- !query 6 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -108,29 +116,29 @@ WHERE t1d IN (SELECT t2d GROUP BY t1b ORDER BY t1b DESC NULLS FIRST LIMIT 1 --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 1 NULL --- !query 6 +-- !query 7 SELECT * FROM t1 WHERE t1b NOT IN (SELECT t2b FROM t2 WHERE t2b > 6 LIMIT 2) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 --- !query 7 +-- !query 8 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -141,7 +149,7 @@ WHERE t1d NOT IN (SELECT t2d GROUP BY t1b ORDER BY t1b NULLS last LIMIT 1 --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output 1 6 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out new file mode 100644 index 0000000000000..a16e98af9a417 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out @@ -0,0 +1,54 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Case 5 + -- (one null column with no match -> 
row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 1 schema +struct +-- !query 1 output +NULL 1 + + +-- !query 2 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 3 schema +struct +-- !query 3 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out new file mode 100644 index 0000000000000..aa5f64b8ebf55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out @@ -0,0 +1,134 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 9 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 +NULL NULL + + +-- !query 3 +-- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 6 schema +struct +-- !query 6 output +NULL 1 + + +-- !query 7 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 7 schema +struct +-- !query 7 output + + + +-- !query 8 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 8 schema +struct +-- !query 8 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out 
b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out new file mode 100644 index 0000000000000..446447e890449 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out @@ -0,0 +1,69 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 5 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null) +-- !query 1 schema +struct +-- !query 1 output + + + +-- !query 2 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6) +-- !query 4 schema +struct +-- !query 4 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out new file mode 100644 index 0000000000000..f58ebeacc2872 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out @@ -0,0 +1,149 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 + + +-- !query 3 +-- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only 
matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +-- !query 6 schema +struct +-- !query 6 output +2 3 + + +-- !query 7 +-- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 7 schema +struct +-- !query 7 output +2 3 +4 5 +NULL 1 + + +-- !query 8 +-- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 8 schema +struct +-- !query 8 output +NULL 1 + + +-- !query 9 +-- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 9 schema +struct +-- !query 9 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 2586f26f71c35..e49978ddb1ce2 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -109,8 +109,8 @@ struct<> org.apache.spark.sql.AnalysisException Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses: Aggregate [min(outer(t2a#x)) AS min(outer())#x] -+- SubqueryAlias t3 ++- SubqueryAlias `t3` +- Project [t3a#x, t3b#x, t3c#x] - +- SubqueryAlias t3 + +- SubqueryAlias `t3` +- LocalRelation [t3a#x, t3b#x, t3c#x] ; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index 70aeb9373f3c7..c52e5706deeee 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 10 -- !query 0 @@ -33,6 +33,26 @@ struct<> -- !query 3 +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES + (CAST(1 AS DOUBLE), CAST(2 AS STRING), CAST(3 AS STRING)) +AS t1(t4a, t4b, t4c) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES + (CAST(1 AS DECIMAL(18, 0)), CAST(2 AS STRING), CAST(3 AS BIGINT)) +AS t1(t5a, t5b, t5c) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 SELECT ( SELECT max(t2b), min(t2b) FROM t2 @@ -40,14 +60,14 @@ SELECT GROUP BY t2.t2b ) FROM t1 --- !query 3 schema +-- !query 5 schema struct<> --- !query 3 output +-- !query 5 output org.apache.spark.sql.AnalysisException Scalar subquery must return only one column, but got 2; --- !query 4 +-- !query 6 SELECT ( SELECT max(t2b), min(t2b) FROM t2 @@ -55,50 +75,72 @@ SELECT GROUP BY t2.t2b ) FROM t1 --- !query 4 schema +-- !query 6 schema struct<> --- !query 4 output +-- !query 6 output org.apache.spark.sql.AnalysisException Scalar subquery must return only one column, but got 2; --- !query 5 +-- !query 7 
SELECT * FROM t1 WHERE t1a IN (SELECT t2a, t2b FROM t2 WHERE t1a = t2a) --- !query 5 schema +-- !query 7 schema struct<> --- !query 5 output +-- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +Cannot analyze (t1.`t1a` IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 1. -#columns in right hand side: 2. +#columns in left hand side: 1 +#columns in right hand side: 2 Left side columns: -[t1.`t1a`]. +[t1.`t1a`] Right side columns: -[t2.`t2a`, t2.`t2b`].; +[t2.`t2a`, t2.`t2b`]; --- !query 6 +-- !query 8 SELECT * FROM T1 WHERE (t1a, t1b) IN (SELECT t2a FROM t2 WHERE t1a = t2a) --- !query 6 schema +-- !query 8 schema struct<> --- !query 6 output +-- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +Cannot analyze (named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 2. -#columns in right hand side: 1. +#columns in left hand side: 2 +#columns in right hand side: 1 Left side columns: -[t1.`t1a`, t1.`t1b`]. +[t1.`t1a`, t1.`t1b`] Right side columns: -[t2.`t2a`].; +[t2.`t2a`]; + + +-- !query 9 +SELECT * FROM t4 +WHERE +(t4a, t4b, t4c) IN (SELECT t5a, + t5b, + t5c + FROM t5) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(named_struct('t4a', t4.`t4a`, 't4b', t4.`t4b`, 't4c', t4.`t4c`) IN (listquery()))' due to data type mismatch: +The data type of one or more elements in the left hand side of an IN subquery +is not compatible with the data type of the output of the subquery +Mismatched columns: +[(t4.`t4a`:double, t5.`t5a`:decimal(18,0)), (t4.`t4c`:string, t5.`t5c`:bigint)] +Left side: +[double, string, string]. 
+Right side: +[decimal(18,0), string, bigint].; diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index a8bc6faf11262..94af9181225d6 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -83,8 +83,13 @@ select * from range(1, null) -- !query 6 schema struct<> -- !query 6 output -java.lang.IllegalArgumentException -Invalid arguments for resolved function: 1, null +org.apache.spark.sql.AnalysisException +error: table-valued function range with alternatives: + (end: long) + (start: long, end: long) + (start: long, end: long, step: long) + (start: long, end: long, step: long, numPartitions: integer) +cannot be applied to: (integer, null); line 1 pos 14 -- !query 7 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out new file mode 100644 index 0000000000000..b23a62dacef7c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +SELECT array_join(array(true, false), ', ') +-- !query 0 schema +struct +-- !query 0 output +true, false + + +-- !query 1 +SELECT array_join(array(2Y, 1Y), ', ') +-- !query 1 schema +struct +-- !query 1 output +2, 1 + + +-- !query 2 +SELECT array_join(array(2S, 1S), ', ') +-- !query 2 schema +struct +-- !query 2 output +2, 1 + + +-- !query 3 +SELECT array_join(array(2, 1), ', ') +-- !query 3 schema +struct +-- !query 3 output +2, 1 + + +-- !query 4 +SELECT array_join(array(2L, 1L), ', ') +-- !query 4 schema +struct +-- !query 4 output +2, 1 + + +-- !query 5 +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ') +-- !query 5 schema +struct +-- !query 5 output +9223372036854775809, 9223372036854775808 + + +-- !query 6 +SELECT array_join(array(2.0D, 1.0D), ', ') +-- !query 6 schema +struct +-- !query 6 output +2.0, 1.0 + + +-- !query 7 +SELECT array_join(array(float(2.0), float(1.0)), ', ') +-- !query 7 schema +struct +-- !query 7 output +2.0, 1.0 + + +-- !query 8 +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ') +-- !query 8 schema +struct +-- !query 8 output +2016-03-14, 2016-03-13 + + +-- !query 9 +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ') +-- !query 9 schema +struct +-- !query 9 output +2016-11-15 20:54:00, 2016-11-12 20:54:00 + + +-- !query 10 +SELECT array_join(array('a', 'b'), ', ') +-- !query 10 schema +struct +-- !query 10 output +a, b diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 62befc5ca0f15..6c6d3110d7d0d 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 14 -- !query 0 @@ -306,12 +306,14 @@ SELECT (tinyint_array1 || smallint_array2) ts_array, (smallint_array1 || int_array2) si_array, (int_array1 || bigint_array2) ib_array, + (bigint_array1 || decimal_array2) bd_array, + 
(decimal_array1 || double_array2) dd_array, (double_array1 || float_array2) df_array, (string_array1 || data_array2) std_array, (timestamp_array1 || string_array2) tst_array, (string_array1 || int_array2) sti_array FROM various_arrays -- !query 13 schema -struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +struct,si_array:array,ib_array:array,bd_array:array,dd_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> -- !query 13 output -[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,9223372036854775808,9223372036854775809] [9.223372036854776E18,9.223372036854776E18,3.0,4.0] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index 6bfdb84548d4d..cbf44548b3cce 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 36 +-- Number of queries: 40 -- !query 0 @@ -114,190 +114,222 @@ struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.00000000000000000 -- !query 13 -select (5e36 + 0.1) + 5e36 +select 2.35E10 * 1.0 -- !query 13 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)> -- !query 13 output -NULL +23500000000 -- !query 14 -select (-4e36 - 0.1) - 7e36 +select (5e36 + 0.1) + 5e36 -- !query 14 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 14 output NULL -- !query 15 -select 12345678901234567890.0 * 12345678901234567890.0 +select (-4e36 - 0.1) - 7e36 -- !query 15 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 15 output NULL -- !query 16 -select 1e35 / 0.1 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 16 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 16 output NULL -- !query 17 -select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +select 1e35 / 0.1 -- !query 17 schema -struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> -- !query 17 output -10012345678912345678912345678911.246907 +NULL -- !query 18 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 
1.2345678901234567890E30 * 1.2345678901234567890E25 -- !query 18 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)> -- !query 18 output -138698367904130467.654320988515622621 +NULL -- !query 19 -select 12345678912345.123456789123 / 0.000000012345678 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 -- !query 19 schema -struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> -- !query 19 output -1000000073899961059796.725866332 +10012345678912345678912345678911.246907 -- !query 20 -set spark.sql.decimalOperations.allowPrecisionLoss=false +select 123456789123456789.1234567890 * 1.123456789123456789 -- !query 20 schema -struct +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> -- !query 20 output -spark.sql.decimalOperations.allowPrecisionLoss false +138698367904130467.654320988515622621 -- !query 21 -select id, a+b, a-b, a*b, a/b from decimals_test order by id +select 12345678912345.123456789123 / 0.000000012345678 -- !query 21 schema -struct +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> -- !query 21 output -1 1099 -899 NULL 0.1001001001001001 -2 24690.246 0 NULL 1 -3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 -4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 +1000000073899961059796.725866332 -- !query 22 -select id, a*10, b/10 from decimals_test order by id +set spark.sql.decimalOperations.allowPrecisionLoss=false -- !query 22 schema -struct +struct -- !query 22 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.1123456789123456789 +spark.sql.decimalOperations.allowPrecisionLoss false -- !query 23 -select 10.3 * 3.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 23 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +struct -- !query 23 output -30.9 +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 -- !query 24 -select 10.3000 * 3.0 +select id, a*10, b/10 from decimals_test order by id -- !query 24 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +struct -- !query 24 output -30.9 +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 -- !query 25 -select 10.30000 * 30.0 +select 10.3 * 3.0 -- !query 25 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 25 output -309 +30.9 -- !query 26 -select 10.300000000000000000 * 3.000000000000000000 +select 10.3000 * 3.0 -- !query 26 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> 
+struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 26 output 30.9 -- !query 27 -select 10.300000000000000000 * 3.0000000000000000000 +select 10.30000 * 30.0 -- !query 27 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> -- !query 27 output -NULL +309 -- !query 28 -select (5e36 + 0.1) + 5e36 +select 10.300000000000000000 * 3.000000000000000000 -- !query 28 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> -- !query 28 output -NULL +30.9 -- !query 29 -select (-4e36 - 0.1) - 7e36 +select 10.300000000000000000 * 3.0000000000000000000 -- !query 29 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> -- !query 29 output NULL -- !query 30 -select 12345678901234567890.0 * 12345678901234567890.0 +select 2.35E10 * 1.0 -- !query 30 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)> -- !query 30 output -NULL +23500000000 -- !query 31 -select 1e35 / 0.1 +select (5e36 + 0.1) + 5e36 -- !query 31 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 31 output NULL -- !query 32 -select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +select (-4e36 - 0.1) - 7e36 -- !query 32 schema -struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 32 output NULL -- !query 33 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 33 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 33 output NULL -- !query 34 -select 12345678912345.123456789123 / 0.000000012345678 +select 1e35 / 0.1 -- !query 34 schema -struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> -- !query 34 output NULL -- !query 35 -drop table decimals_test +select 1.2345678901234567890E30 * 1.2345678901234567890E25 -- !query 35 schema -struct<> +struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)> -- !query 35 output +NULL + + +-- !query 36 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 36 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + 
CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +-- !query 36 output +NULL + + +-- !query 37 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 37 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 37 output +NULL + + +-- !query 38 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 38 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +-- !query 38 output +NULL + + +-- !query 39 +drop table decimals_test +-- !query 39 schema +struct<> +-- !query 39 output diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out new file mode 100644 index 0000000000000..35740094ba53e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -0,0 +1,179 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 16 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), + map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 1 schema +struct>> +-- !query 1 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 2 +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 2 schema +struct>> +-- !query 2 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 3 +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 3 schema +struct>> +-- !query 3 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 4 +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 4 schema +struct>> +-- !query 4 output +{2.0:{"k":2.0,"v1":1.0,"v2":1.0}} + + +-- !query 5 +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to 
function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 + + +-- !query 6 +SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 6 schema +struct>> +-- !query 6 output +{2:{"k":2,"v1":null,"v2":1},922337203685477897945456575809789456:{"k":922337203685477897945456575809789456,"v1":922337203685477897945456575809789456,"v2":null}} + + +-- !query 7 +SELECT map_zip_with(decimal_map1, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 7 schema +struct>> +-- !query 7 output +{2.0:{"k":2.0,"v1":null,"v2":1.0},9.223372036854779E35:{"k":9.223372036854779E35,"v1":922337203685477897945456575809789456,"v2":null}} + + +-- !query 8 +SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 + + +-- !query 9 +SELECT map_zip_with(decimal_map2, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 9 schema +struct>> +-- !query 9 output +{2.0:{"k":2.0,"v1":null,"v2":1.0},9.223372036854778:{"k":9.223372036854778,"v1":9.22337203685477897945456575809789456,"v2":null}} + + +-- !query 10 +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 10 schema +struct>> +-- !query 10 output +{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}} + + +-- !query 11 +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 11 schema +struct>> +-- !query 11 output +{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}} + + +-- !query 12 +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 12 schema +struct>> +-- !query 12 output +{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}} + + +-- !query 13 +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 13 schema +struct>> +-- !query 13 output +{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}} + + +-- !query 14 +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 14 schema +struct,struct,v1:array,v2:array>>> +-- !query 14 output +{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}} + + +-- !query 15 +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 15 schema +struct,struct,v1:struct,v2:struct>>> +-- !query 15 output +{{"col1":1,"col2":2}:{"k":{"col1":1,"col2":2},"v1":{"col1":1,"col2":2},"v2":{"col1":1,"col2":2}}} diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out new 
file mode 100644 index 0000000000000..efc88e47209a6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -0,0 +1,144 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps +-- !query 1 schema +struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> +-- !query 1 output +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} + + +-- !query 2 +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, 
int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(bigint_map1, decimal_map2) bd_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps +-- !query 2 schema +struct,si_map:map,ib_map:map,bd_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> +-- !query 2 output +{1:2,3:4} {1:2,7:8} {4:6,8:9} {6:7,9223372036854775808:9223372036854775809} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} + + +-- !query 3 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 + + +-- !query 4 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map]; line 2 pos 4 + + +-- !query 5 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,struct>]; line 2 pos 4 + + +-- !query 6 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 + + +-- !query 7 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out new file mode 100644 index 0000000000000..d7d009a64bf84 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out @@ -0,0 +1,93 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 
6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px +-- !query 1 schema +struct +-- !query 1 output +1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 +2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 +3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 +4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4 +5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1 + + +-- !query 2 +select id, regr_count(y,x) over (partition by px) from t1 order by id +-- !query 2 schema +struct +-- !query 2 output +101 4 +102 4 +103 4 +104 4 +201 4 +202 4 +203 4 +204 4 +301 4 +302 4 +303 4 +304 4 +401 4 +402 4 +403 4 +404 4 +501 0 +502 0 +503 0 +504 0 +601 0 +602 0 +603 0 +604 0 +701 0 +702 0 +703 0 +704 0 +800 1 diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index 4815a578b1029..87824ab81cdf7 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -33,8 +33,8 @@ SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1 -- !query 3 schema struct<> -- !query 3 output -java.lang.AssertionError -assertion failed: Incorrect number of children +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function default.myDoubleAvg. 
Expected: 1; Found: 2; line 1 pos 7 -- !query 4 diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index d123b7fdbe0cf..b023df825d814 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 16 -- !query 0 @@ -105,23 +105,29 @@ struct -- !query 9 -DROP VIEW IF EXISTS t1 +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1 -- !query 9 schema -struct<> +struct,str:string> -- !query 9 output - +{1:2,3:null} 1 +{1:2} str -- !query 10 -DROP VIEW IF EXISTS t2 +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1 -- !query 10 schema -struct<> +struct,str:string> -- !query 10 output - +[1,2,3,null] 1 +[1,2] str -- !query 11 -DROP VIEW IF EXISTS p1 +DROP VIEW IF EXISTS t1 -- !query 11 schema struct<> -- !query 11 output @@ -129,7 +135,7 @@ struct<> -- !query 12 -DROP VIEW IF EXISTS p2 +DROP VIEW IF EXISTS t2 -- !query 12 schema struct<> -- !query 12 output @@ -137,8 +143,24 @@ struct<> -- !query 13 -DROP VIEW IF EXISTS p3 +DROP VIEW IF EXISTS p1 -- !query 13 schema struct<> -- !query 13 output + + +-- !query 14 +DROP VIEW IF EXISTS p2 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP VIEW IF EXISTS p3 +-- !query 15 schema +struct<> +-- !query 15 output + diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata new file mode 100644 index 0000000000000..372180b2096ee --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"04d960cd-d38f-4ce6-b8d0-ebcf84c9dccc"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 new file mode 100644 index 0000000000000..807d7b0063b96 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 
+{"batchWatermarkMs":0,"batchTimestampMs":1531292029003,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 new file mode 100644 index 0000000000000..cce541073fb4b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000..193524ffe15b5 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta differ diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata new file mode 100644 index 0000000000000..d6be7fbffa9b7 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata @@ -0,0 +1 @@ +{"id":"549eeb1a-d762-420c-bb44-3fd6d73a5268"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 new file mode 100644 index 0000000000000..43db49d052894 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531172902041,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 new file mode 100644 index 0000000000000..8cc898e81017f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":10000,"batchTimestampMs":1531172902217,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata new file mode 100644 index 0000000000000..c160d737278e1 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"2f32aca2-1b97-458f-a48f-109328724f09"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 new file mode 100644 index 0000000000000..acdc6e69e975a --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784347136,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 new file mode 100644 index 0000000000000..27353e8724507 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784349160,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000..281b21e960909 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta new file mode 100644 index 0000000000000..b701841d71535 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta new file mode 
100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000..f4fb2520a4ac4 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 @@ 
-0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata new file mode 100644 index 0000000000000..f205857e6876f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata @@ -0,0 +1 @@ +{"id":"73f7f943-0a08-4ffb-a504-9fa88ff7612a"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 new file mode 100644 index 0000000000000..8fa80bedc2285 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531991874513,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 new file mode 100644 index 0000000000000..2248a58fea006 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531991878604,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta new file mode 100644 index 0000000000000..171aa58a06e21 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/2.delta new file mode 100644 index 0000000000000..cfb3a481deb59 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/test-data/comments-whitespaces.csv b/sql/core/src/test/resources/test-data/comments-whitespaces.csv new file mode 100644 index 0000000000000..2737978f83a5e --- /dev/null +++ b/sql/core/src/test/resources/test-data/comments-whitespaces.csv @@ -0,0 +1,8 @@ +# The file contains comments, whitespaces and empty lines +colA +# empty line + +# the line with a few whitespaces + +# 
int value with leading and trailing whitespaces + "a" diff --git a/sql/core/src/test/resources/test-data/parquet-1217.parquet b/sql/core/src/test/resources/test-data/parquet-1217.parquet new file mode 100644 index 0000000000000..eb2dc4f799070 Binary files /dev/null and b/sql/core/src/test/resources/test-data/parquet-1217.parquet differ diff --git a/sql/core/src/test/resources/test-data/utf16LE.json b/sql/core/src/test/resources/test-data/utf16LE.json new file mode 100644 index 0000000000000..ce4117fd299df Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf16LE.json differ diff --git a/sql/core/src/test/resources/test-data/utf16WithBOM.json b/sql/core/src/test/resources/test-data/utf16WithBOM.json new file mode 100644 index 0000000000000..cf4d29328b860 Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf16WithBOM.json differ diff --git a/sql/core/src/test/resources/test-data/utf32BEWithBOM.json b/sql/core/src/test/resources/test-data/utf32BEWithBOM.json new file mode 100644 index 0000000000000..6c7733c577872 Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf32BEWithBOM.json differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 137c5bea2abb9..d635912cf7205 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -279,4 +280,16 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } } + + test("SPARK-24013: unneeded compress can cause performance issues with sorted input") { + val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY) + var compressCounts = 0 + (1 to 10000000).foreach { i => + buffer.add(i) + if (buffer.isCompressed) compressCounts += 1 + } + assert(compressCounts > 0) + buffer.quantileSummaries + assert(buffer.isCompressed) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index e51aad021fcbf..d95794d624033 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -54,7 +54,7 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with B plan foreach { case s: WholeStageCodegenExec => codegenSubtrees += s - case s => s + case _ => } codegenSubtrees.toSeq.foreach { subtree => val code = subtree.doCodeGen()._2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 669e5f2bf4e65..60c73df88896b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,12 +22,12 @@ import scala.concurrent.duration._ import 
scala.language.postfixOps import org.apache.spark.CleanerListener +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -52,7 +52,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head @@ -78,29 +78,10 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { plan.collect { case InMemoryTableScanExec(_, _, relation) => - getNumInMemoryTablesRecursively(relation.child) + 1 + getNumInMemoryTablesRecursively(relation.cachedPlan) + 1 }.sum } - test("withColumn doesn't invalidate cached dataframe") { - var evalCount = 0 - val myUDF = udf((x: String) => { evalCount += 1; "result" }) - val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) - df.cache() - - df.collect() - assert(evalCount === 1) - - df.collect() - assert(evalCount === 1) - - val df2 = df.withColumn("newColumn", lit(1)) - df2.collect() - - // We should not reevaluate the cached dataframe - assert(evalCount === 1) - } - test("cache temp table") { withTempView("tempTable") { testData.select('key).createOrReplaceTempView("tempTable") @@ -200,7 +181,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { spark.table("testData").queryExecution.withCachedData.collect { - case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r + case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r }.size } @@ -367,12 +348,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 @@ -794,4 +775,94 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } } } + + private def checkIfNoJobTriggered[T](f: => T): T = { + var numJobTrigered = 0 + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numJobTrigered += 1 + } + } + sparkContext.addSparkListener(jobListener) + try { + val result = f + sparkContext.listenerBus.waitUntilEmpty(10000L) + assert(numJobTrigered === 0) + result + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + + 
test("SPARK-23880 table cache should be lazy and don't trigger any jobs") { + val cachedData = checkIfNoJobTriggered { + spark.range(1002).filter('id > 1000).orderBy('id.desc).cache() + } + assert(cachedData.collect === Seq(1001)) + } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t1") + assert(!spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withView("t1") { + withTempView("t2") { + sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1") + + sql("CACHE TABLE t1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(!spark.catalog.isCached("t2")) + } + } + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withTempView("t1", "t2") { + sql("CACHE TABLE t") + sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t") + assert(!spark.catalog.isCached("t")) + assert(!spark.catalog.isCached("t1")) + assert(!spark.catalog.isCached("t2")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 7c45be21961d3..2182bd7eadd63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.util.Locale + +import scala.collection.JavaConverters._ + import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ @@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isin(1.toShort, "2")), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin("3", 2.toLong)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, "1")), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = 
Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - intercept[AnalysisException] { + val e = intercept[AnalysisException] { df2.filter($"a".isin($"b")) } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Scala Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Java Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b").asJava)) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } test("&&") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 949505e449fd7..276496be3d62c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,9 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id.desc) + // Range has range partitioning in its output now. To have a range shuffle, we + // need to run a repartition first. 
+ val data = spark.range(0, n, 1, 1).repartition(10).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7776e36702ad..85b3ca11383f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import scala.util.Random -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.scalatest.Matchers.the + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -36,6 +36,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ + val absTol = 1e-8 + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -416,7 +418,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("moments") { - val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) @@ -556,11 +557,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") - val limit2Df = df.limit(2) - checkAnswer( - limit2Df.groupBy("id").count().select($"id"), - limit2Df.select($"id")) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) + } } test("SPARK-17237 remove backticks in a pivot result schema") { @@ -686,4 +689,44 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21896: Window functions inside aggregate functions") { + def checkWindowError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("not allowed to use a window function")) + } + + checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a))))) + checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a))))) + checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b))))) + checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a))))) + checkWindowError( + testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3)) + checkAnswer( + testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3), + Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil) + + checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + 
checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) + checkAnswer( + sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"), + Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) + } + + test("SPARK-24788: RelationalGroupedDataset.toString with unresolved exprs should not fail") { + // Checks if these raise no exception + assert(testData.groupBy('key).toString.contains( + "[grouping expressions: [key], value: [key: int, value: string], type: GroupBy]")) + assert(testData.groupBy(col("key")).toString.contains( + "[grouping expressions: [key], value: [key: int, value: string], type: GroupBy]")) + assert(testData.groupBy(current_date()).toString.contains( + "grouping expressions: [current_date(None)], value: [key: int, value: string], " + + "type: GroupBy]")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25e5cd60dd236..156e54300e38b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.util.TimeZone import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -62,6 +65,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(row.getMap[Int, String](0) === Map(2 -> "a")) } + test("map with arrays") { + val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v") + val expectedType = MapType(IntegerType, StringType, valueContainsNull = true) + val row = df1.select(map_from_arrays($"k", $"v")).first() + assert(row.schema(0).dataType === expectedType) + assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b")) + checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) + + val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v") + checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b")))) + + val df3 = Seq((null, null)).toDF("k", "v") + checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null))) + + val df4 = Seq((1, "a")).toDF("k", "v") + intercept[AnalysisException] { + df4.select(map_from_arrays($"k", $"v")) + } + + val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") + intercept[RuntimeException] { + df5.select(map_from_arrays($"k", $"v")).collect + } + + val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") + intercept[RuntimeException] { + df6.select(map_from_arrays($"k", $"v")).collect + } + } + test("struct with column name") { val df = Seq((1, "str")).toDF("a", "b") val row = df.select(struct("a", "b")).first() @@ -276,7 +309,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("sort_array function") { + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), @@ 
-286,28 +319,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.select(sort_array($"a", false), sort_array($"b", false)), Seq( Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a)", "sort_array(b)"), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), Seq( Row(Seq(1, 2, 3), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) @@ -324,40 +357,137 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) + + checkAnswer( + df.select(array_sort($"a"), array_sort($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_sort(a)", "array_sort(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + + checkAnswer( + df2.selectExpr("array_sort(a)"), + Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) + ) + + assert(intercept[AnalysisException] { + df3.selectExpr("array_sort(a)").collect() + }.getMessage().contains("only supports array input")) } - test("array size function") { + def testSizeOfArray(sizeOfNull: Any): Unit = { val df = Seq( (Seq[Int](1, 2), "x"), (Seq[Int](), "y"), (Seq[Int](1, 2, 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) + + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("cardinality(a)"), Seq(Row(2L), Row(0L), Row(3L), Row(sizeOfNull))) } - test("map size function") { + test("array size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfArray(sizeOfNull = -1) + } + } + + test("array size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfArray(sizeOfNull = null) + } + } + + test("dataframe arrays_zip function") { + val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2") + val df2 = Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3") + val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2") + val df4 = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2") + val df5 = Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4") + val df6 = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null))) + .toDF("v1", "v2", "v3", "v4") + val df7 = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2") + val df8 = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2") + + val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6))) + checkAnswer(df1.select(arrays_zip($"val1", $"val2")), 
expectedValue1) + checkAnswer(df1.selectExpr("arrays_zip(val1, val2)"), expectedValue1) + + val expectedValue2 = Row(Seq(Row("a", true, 10), Row("b", false, 11))) + checkAnswer(df2.select(arrays_zip($"val1", $"val2", $"val3")), expectedValue2) + checkAnswer(df2.selectExpr("arrays_zip(val1, val2, val3)"), expectedValue2) + + val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6))) + checkAnswer(df3.select(arrays_zip($"val1", $"val2")), expectedValue3) + checkAnswer(df3.selectExpr("arrays_zip(val1, val2)"), expectedValue3) + + val expectedValue4 = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null))) + checkAnswer(df4.select(arrays_zip($"val1", $"val2")), expectedValue4) + checkAnswer(df4.selectExpr("arrays_zip(val1, val2)"), expectedValue4) + + val expectedValue5 = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null))) + checkAnswer(df5.select(arrays_zip($"val1", $"val2", $"val3", $"val4")), expectedValue5) + checkAnswer(df5.selectExpr("arrays_zip(val1, val2, val3, val4)"), expectedValue5) + + val expectedValue6 = Row(Seq( + Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null))) + checkAnswer(df6.select(arrays_zip($"v1", $"v2", $"v3", $"v4")), expectedValue6) + checkAnswer(df6.selectExpr("arrays_zip(v1, v2, v3, v4)"), expectedValue6) + + val expectedValue7 = Row(Seq( + Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2))) + checkAnswer(df7.select(arrays_zip($"v1", $"v2")), expectedValue7) + checkAnswer(df7.selectExpr("arrays_zip(v1, v2)"), expectedValue7) + + val expectedValue8 = Row(Seq( + Row(Array[Byte](1.toByte, 5.toByte), null))) + checkAnswer(df8.select(arrays_zip($"v1", $"v2")), expectedValue8) + checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) + } + + test("SPARK-24633: arrays_zip splits input processing correctly") { + Seq("true", "false").foreach { wholestageCodegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) { + val df = spark.range(1) + val exprs = (0 to 5).map(x => array($"id" + lit(x))) + checkAnswer(df.select(arrays_zip(exprs: _*)), + Row(Seq(Row(0, 1, 2, 3, 4, 5)))) + } + } + } + + def testSizeOfMap(sizeOfNull: Any): Unit = { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), (Map[Int, Int](), "y"), (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) + + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + } + + test("map size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfMap(sizeOfNull = -1: Int) + } + } + + test("map size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfMap(sizeOfNull = null) + } } test("map_keys/map_values function") { @@ -376,11 +506,195 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_entries") { + // Primitive-type elements + val idf = Seq( + Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), + Map[Int, Int](), + null + ).toDF("m") + val iExpected = Seq( + Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))), + Row(Seq.empty), + Row(null) + ) + + def testPrimitiveType(): Unit = { + checkAnswer(idf.select(map_entries('m)), iExpected) + checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) + checkAnswer(idf.selectExpr("map_entries(map(1, 
null, 2, null))"), + Seq.fill(iExpected.length)(Row(Seq(Row(1, null), Row(2, null))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + idf.cache() + testPrimitiveType() + + // Non-primitive-type elements + val sdf = Seq( + Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"), + Map[String, String]("a" -> null, "b" -> null), + Map[String, String](), + null + ).toDF("m") + val sExpected = Seq( + Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))), + Row(Seq(Row("a", null), Row("b", null))), + Row(Seq.empty), + Row(null) + ) + + def testNonPrimitiveType(): Unit = { + checkAnswer(sdf.select(map_entries('m)), sExpected) + checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() + } + + test("map_concat function") { + val df1 = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)), + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), + (null, Map[Int, Int](3 -> 300, 4 -> 400)) + ).toDF("map1", "map2") + + val expected1a = Seq( + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a) + checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a) + + val expected1b = Seq( + Row(Map(1 -> 100, 2 -> 200)), + Row(Map(1 -> 100, 2 -> 200)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b) + checkAnswer(df1.select(map_concat('map1)), expected1b) + + val df2 = Seq( + ( + Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200), + Map[String, Int]("3" -> 300, "4" -> 400) + ) + ).toDF("map1", "map2") + + val expected2 = Seq(Row(Map())) + + checkAnswer(df2.selectExpr("map_concat()"), expected2) + checkAnswer(df2.select(map_concat()), expected2) + + val df3 = { + val schema = StructType( + StructField("map1", MapType(StringType, IntegerType, true), false) :: + StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil + ) + val data = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4)) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val expected3 = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4)) + ) + + checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3) + checkAnswer(df3.select(map_concat('map1, 'map2)), expected3) + + val expectedMessage1 = "input to function map_concat should all be the same type" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, map2)").collect() + }.getMessage().contains(expectedMessage1)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, 'map2)).collect() + }.getMessage().contains(expectedMessage1)) + + val expectedMessage2 = "input to function map_concat should all be of type map" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, 12)").collect() + }.getMessage().contains(expectedMessage2)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, lit(12))).collect() + 
}.getMessage().contains(expectedMessage2)) + } + + test("map_from_entries function") { + // Test cases with primitive-type keys and values + val idf = Seq( + Seq((1, 10), (2, 20), (3, 10)), + Seq((1, 10), null, (2, 20)), + Seq.empty, + null + ).toDF("a") + val iExpected = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), + Row(null), + Row(Map.empty), + Row(null)) + + def testPrimitiveType(): Unit = { + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq.fill(iExpected.length)(Row(Map(1 -> null, 2 -> null)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + idf.cache() + testPrimitiveType() + + // Test cases with non-primitive-type keys and values + val sdf = Seq( + Seq(("a", "aa"), ("b", "bb"), ("c", "aa")), + Seq(("a", "aa"), null, ("b", "bb")), + Seq(("a", null), ("b", null)), + Seq.empty, + null + ).toDF("a") + val sExpected = Seq( + Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(null), + Row(Map("a" -> null, "b" -> null)), + Row(Map.empty), + Row(null)) + + def testNonPrimitiveType(): Unit = { + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() + } + test("array contains function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") // Simple test cases checkAnswer( @@ -391,6 +705,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) + checkAnswer( + df.select(array_contains(df("a"), df("c"))), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, c)"), + Seq(Row(true), Row(false)) + ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -413,6 +735,91 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("arrays_overlap function") { + val df = Seq( + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) + ).toDF("a", "b") + + val answer = Seq(Row(false), Row(null), Row(true)) + + checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) + checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + + checkAnswer( + Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b").selectExpr("arrays_overlap(a, b)"), + Row(true)) + + intercept[AnalysisException] { + sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(null, null)") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(map(1, 2), map(3, 4))") + } + } + + test("slice function") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5) + ).toDF("x") + + val answer = Seq(Row(Seq(2, 3)), Row(Seq(5))) + + checkAnswer(df.select(slice(df("x"), 2, 2)), answer) + checkAnswer(df.selectExpr("slice(x, 2, 2)"), 
answer) + + val answerNegative = Seq(Row(Seq(3)), Row(Seq(5))) + checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative) + checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative) + } + + test("array_join function") { + val df = Seq( + (Seq[String]("a", "b"), ","), + (Seq[String]("a", null, "b"), ","), + (Seq.empty[String], ",") + ).toDF("x", "delimiter") + + checkAnswer( + df.select(array_join(df("x"), ";")), + Seq(Row("a;b"), Row("a;b"), Row("")) + ) + checkAnswer( + df.select(array_join(df("x"), ";", "NULL")), + Seq(Row("a;b"), Row("a;NULL;b"), Row("")) + ) + checkAnswer( + df.selectExpr("array_join(x, delimiter)"), + Seq(Row("a,b"), Row("a,b"), Row(""))) + checkAnswer( + df.selectExpr("array_join(x, delimiter, 'NULL')"), + Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + + val idf = Seq(Seq(1, 2, 3)).toDF("x") + + checkAnswer( + idf.select(array_join(idf("x"), ", ")), + Seq(Row("1, 2, 3")) + ) + checkAnswer( + idf.selectExpr("array_join(x, ', ')"), + Seq(Row("1, 2, 3")) + ) + intercept[AnalysisException] { + idf.selectExpr("array_join(x, 1)") + } + intercept[AnalysisException] { + idf.selectExpr("array_join(x, ', ', 1)") + } + } + test("array_min function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), @@ -441,63 +848,129 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } - test("reverse function") { - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on - - // String test cases - val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + test("sequence") { + checkAnswer(Seq((-2, 2)).toDF().select(sequence('_1, '_2)), Seq(Row(Array(-2, -1, 0, 1, 2)))) + checkAnswer(Seq((7, 2, -2)).toDF().select(sequence('_1, '_2, '_3)), Seq(Row(Array(7, 5, 3)))) checkAnswer( - oneRowDF.select(reverse('s)), - Seq(Row("krapS")) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(s)"), - Seq(Row("krapS")) - ) - checkAnswer( - oneRowDF.select(reverse('i)), - Seq(Row("5123")) - ) + spark.sql("select sequence(" + + " cast('2018-01-01 00:00:00' as timestamp)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-03-01' as date)" + + ", interval 1 month)"), + Seq(Row(Array( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))))) + } + + // test type coercion checkAnswer( - oneRowDF.selectExpr("reverse(i)"), - Seq(Row("5123")) - ) + Seq((1.toByte, 3L, 1)).toDF().select(sequence('_1, '_2, '_3)), + Seq(Row(Array(1L, 2L, 3L)))) + checkAnswer( - oneRowDF.selectExpr("reverse(null)"), - Seq(Row(null)) - ) + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) - // Array test cases (primitive-type elements) - val idf = Seq( + // test invalid data types + intercept[AnalysisException] { + Seq((true, false)).toDF().selectExpr("sequence(_1, _2)") + } + intercept[AnalysisException] { + Seq((true, false, 42)).toDF().selectExpr("sequence(_1, _2, _3)") + } + intercept[AnalysisException] { + Seq((1, 
2, 0.5)).toDF().selectExpr("sequence(_1, _2, _3)") + } + } + + test("reverse function - string") { + val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + def testString(): Unit = { + checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS"))) + checkAnswer(oneRowDF.selectExpr("reverse(s)"), Seq(Row("krapS"))) + checkAnswer(oneRowDF.select(reverse('i)), Seq(Row("5123"))) + checkAnswer(oneRowDF.selectExpr("reverse(i)"), Seq(Row("5123"))) + checkAnswer(oneRowDF.selectExpr("reverse(null)"), Seq(Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + oneRowDF.cache() + testString() + } + + test("reverse function - array for primitive type not containing null") { + val idfNotContainsNull = Seq( Seq(1, 9, 8, 7), Seq(5, 8, 9, 7, 2), Seq.empty, null ).toDF("i") - checkAnswer( - idf.select(reverse('i)), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - idf.filter(dummyFilter('i)).select(reverse('i)), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - idf.selectExpr("reverse(i)"), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), - Seq(Row(Seq(null, 2, null, 1))) - ) - checkAnswer( - oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), - Seq(Row(Seq(null, 2, null, 1))) - ) + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer( + idfNotContainsNull.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idfNotContainsNull.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfNotContainsNull.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("reverse function - array for primitive type containing null") { + val idfContainsNull = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(null, 5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer( + idfContainsNull.select(reverse('i)), + Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idfContainsNull.selectExpr("reverse(i)"), + Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null)) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfContainsNull.cache() + testArrayOfPrimitiveTypeContainsNull() + } - // Array test cases (non-primitive-type elements) + test("reverse function - array for non-primitive type") { val sdf = Seq( Seq("c", "a", "b"), Seq("b", null, "c", null), @@ -505,41 +978,45 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("s") - checkAnswer( - sdf.select(reverse('s)), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - sdf.filter(dummyFilter('s)).select(reverse('s)), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, 
"b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - sdf.selectExpr("reverse(s)"), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), - Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) - ) - checkAnswer( - oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), - Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) - ) + def testArrayOfNonPrimitiveType(): Unit = { + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq.fill(sdf.count().toInt)(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + } - // Error test cases - intercept[AnalysisException] { - oneRowDF.selectExpr("reverse(struct(1, 'a'))") + // Test with local relation, the Project will be evaluated without codegen + testArrayOfNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testArrayOfNonPrimitiveType() + } + + test("reverse function - data type mismatch") { + val ex1 = intercept[AnalysisException] { + sql("select reverse(struct(1, 'a'))") } - intercept[AnalysisException] { - oneRowDF.selectExpr("reverse(map(1, 'a'))") + assert(ex1.getMessage.contains("data type mismatch")) + + val ex2 = intercept[AnalysisException] { + sql("select reverse(map(1, 'a'))") } + assert(ex2.getMessage.contains("data type mismatch")) } test("array position function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") checkAnswer( df.select(array_position(df("a"), 1)), @@ -549,10 +1026,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(a, 1)"), Seq(Row(1L), Row(0L)) ) - checkAnswer( - df.select(array_position(df("a"), null)), - Seq(Row(null), Row(null)) + df.selectExpr("array_position(a, c)"), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.select(array_position(df("a"), df("c"))), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.select(array_position(df("a"), null)), + Seq(Row(null), Row(null)) ) checkAnswer( df.selectExpr("array_position(a, null)"), @@ -567,14 +1051,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(array(1, null), array(1, null)[0])"), Seq(Row(1L), Row(1L)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } test("element_at function") { val df = Seq( - (Seq[String]("1", "2", "3")), - (Seq[String](null, "")), - (Seq[String]()) - ).toDF("a") + (Seq[String]("1", "2", "3"), 1), + (Seq[String](null, ""), -1), + (Seq[String](), 2) + ).toDF("a", "b") intercept[Exception] { checkAnswer( @@ -592,6 +1081,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), 4)), Seq(Row(null), Row(null), Row(null)) ) + checkAnswer( + df.select(element_at(df("a"), df("b"))), + Seq(Row("1"), Row(""), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, b)"), + Seq(Row("1"), Row(""), Row(null)) + ) checkAnswer( df.select(element_at(df("a"), 1)), 
@@ -615,145 +1112,1580 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("element_at(a, -1)"), Seq(Row("3"), Row(""), Row(null)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") + } + assert(e.message.contains( + "argument 1 requires (array or map) type, however, '`_1`' is of string type")) + } + + test("array_union functions") { + val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1, 2, 3, 4)) + checkAnswer(df1.select(array_union($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_union(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b") + val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkAnswer(df2.select(array_union($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_union(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L, 2L, 3L, 4L)) + checkAnswer(df3.select(array_union($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_union(a, b)"), ans3) + + val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkAnswer(df4.select(array_union($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_union(a, b)"), ans4) + + val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("b", "a", "c", null, "g")) + checkAnswer(df5.select(array_union($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) + + val df6 = Seq((null, Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df6.select(array_union($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df6.selectExpr("array_union(a, b)") + }.getMessage.contains("data type mismatch")) + + val df7 = Seq((null, null)).toDF("a", "b") + assert(intercept[AnalysisException] { + df7.select(array_union($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df7.selectExpr("array_union(a, b)") + }.getMessage.contains("data type mismatch")) + + val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df8.select(array_union($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df8.selectExpr("array_union(a, b)") + }.getMessage.contains("data type mismatch")) } test("concat function - arrays") { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null val df = Seq( - (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on - // Simple test cases - checkAnswer( - df.selectExpr("array(1, 2, 3L)"), - Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) - ) + def simpleTest(): Unit = { + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + 
df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + } - checkAnswer ( - df.select(concat($"i1", $"s1")), - Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) - ) - checkAnswer( - df.select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.selectExpr("concat(array(1, null), i2, i3)"), - Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + // Test with local relation, the Project will be evaluated without codegen + simpleTest() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + simpleTest() + + // Null test cases + def nullTest(): Unit = { + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + df.unpersist() + nullTest() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + nullTest() + + // Type error test cases + intercept[AnalysisException] { + df.selectExpr("concat(i1, i2, null)") + } + + intercept[AnalysisException] { + df.selectExpr("concat(i1, array(i1, i2))") + } + + val e = intercept[AnalysisException] { + df.selectExpr("concat(map(1, 2), map(3, 4))") + } + assert(e.getMessage.contains("string, binary or array")) + } + + test("flatten function") { + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq(1), null)), + (Seq(null, Seq(1))), + (Seq(null, null)) + ).toDF("i") + + val intDFResult = Seq( + Row(Seq(1, 2, 3, 4, 5, 6)), + Row(Seq(1, 2)), + Row(Seq(1)), + Row(Seq(1)), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + def testInt(): Unit = { + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq("a"), null)), + (Seq(null, Seq("a"))), + (Seq(null, null)) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a")), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + def testString(): Unit = { + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached 
relation, the Project will be evaluated with codegen + strDF.cache() + testString() + + val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + def testArray(): Unit = { + checkAnswer( + arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + checkAnswer( + arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArray() + // Test with cached relation, the Project will be evaluated with codegen + arrDF.cache() + testArray() + + // Error test cases + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } + + test("array_repeat function") { + val strDF = Seq( + ("hi", 2), + (null, 2) + ).toDF("a", "b") + + val strDFTwiceResult = Seq( + Row(Seq("hi", "hi")), + Row(Seq(null, null)) ) - checkAnswer( - df.select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + + def testString(): Unit = { + checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() + + val intDF = { + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", IntegerType))) + val data = Seq( + Row(3, 2), + Row(null, 2) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val intDFTwiceResult = Seq( + Row(Seq(3, 3)), + Row(Seq(null, null)) ) + + def testInt(): Unit = { + checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() + + val nullCountDF = { + val schema = StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))) + val data = Seq( + Row("hi", null), + Row(null, null) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + def testNull(): Unit = { + checkAnswer( + nullCountDF.select(array_repeat($"a", $"b")), + Seq(Row(null), Row(null)) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + testNull() + // Test with cached relation, the Project will be evaluated with codegen + nullCountDF.cache() + testNull() + + // Error test cases + val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b") + + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", $"b")) + } + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", lit("1"))) + } + 
intercept[AnalysisException] { + invalidTypeDF.selectExpr("array_repeat(a, 1.0)") + } + + } + + test("array remove") { + val df = Seq( + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2) + ).toDF("a", "b", "c", "d") checkAnswer( - df.selectExpr("concat(s1, s2, s3)"), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) ) + checkAnswer( - df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + df.select(array_remove($"a", $"d")), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) ) - // Null test cases checkAnswer( - df.select(concat($"i1", $"in")), - Seq(Row(null), Row(null)) + df.selectExpr("array_remove(a, d)"), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) ) + checkAnswer( - df.select(concat($"in", $"i1")), - Seq(Row(null), Row(null)) + df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", + "array_remove(c, \"\")"), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + } + + test("array_distinct functions") { + val df = Seq( + (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")), + (Array.empty[Int], Array.empty[String]), + (null, null) + ).toDF("a", "b") checkAnswer( - df.select(concat($"s1", $"sn")), - Seq(Row(null), Row(null)) + df.select(array_distinct($"a"), array_distinct($"b")), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) ) checkAnswer( - df.select(concat($"sn", $"s1")), - Seq(Row(null), Row(null)) + df.selectExpr("array_distinct(a)", "array_distinct(b)"), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) ) + } - // Type error test cases - intercept[AnalysisException] { - df.selectExpr("concat(i1, i2, null)") - } + // Shuffle expressions should produce same results at retries in the same DataFrame. 
+ private def checkShuffleResult(df: DataFrame): Unit = { + checkAnswer(df, df.collect()) + } - intercept[AnalysisException] { - df.selectExpr("concat(i1, array(i1, i2))") + test("shuffle function - array for primitive type not containing null") { + val idfNotContainsNull = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkShuffleResult(idfNotContainsNull.select(shuffle('i))) + checkShuffleResult(idfNotContainsNull.selectExpr("shuffle(i)")) } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfNotContainsNull.cache() + testArrayOfPrimitiveTypeNotContainsNull() } - private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { - import DataFrameFunctionsSuite.CodegenFallbackExpr - for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { - val c = if (codegenFallback) { - Column(CodegenFallbackExpr(v.expr)) - } else { - v - } - withSQLConf( - (SQLConf.CODEGEN_FALLBACK.key, codegenFallback.toString), - (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { - val df = spark.range(0, 4, 1, 4).withColumn("c", c) - val rows = df.collect() - val rowsAfterCoalesce = df.coalesce(2).collect() - assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " + - s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + test("shuffle function - array for primitive type containing null") { + val idfContainsNull = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(null, 5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") - val df1 = spark.range(0, 2, 1, 2).withColumn("c", c) - val rows1 = df1.collect() - val df2 = spark.range(2, 4, 1, 2).withColumn("c", c) - val rows2 = df2.collect() - val rowsAfterUnion = df1.union(df2).collect() - assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " + - s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") - } + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkShuffleResult(idfContainsNull.select(shuffle('i))) + checkShuffleResult(idfContainsNull.selectExpr("shuffle(i)")) } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfContainsNull.cache() + testArrayOfPrimitiveTypeContainsNull() } - test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + - "coalesce or union") { - Seq( - monotonically_increasing_id(), spark_partition_id(), - rand(Random.nextLong()), randn(Random.nextLong()) - ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) + test("shuffle function - array for non-primitive type") { + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkShuffleResult(sdf.select(shuffle('s))) + checkShuffleResult(sdf.selectExpr("shuffle(s)")) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() } - test("SPARK-21281 use string types by default if array and map have no argument") { - val ds = spark.range(1) - var expectedSchema = new StructType() - .add("x", 
ArrayType(StringType, containsNull = false), nullable = false) - assert(ds.select(array().as("x")).schema == expectedSchema) - expectedSchema = new StructType() - .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) - assert(ds.select(map().as("x")).schema == expectedSchema) + test("array_except functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1)) + checkAnswer(df1.select(array_except($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_except(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(1, 5)) + checkAnswer(df2.select(array_except($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_except(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L)) + checkAnswer(df3.select(array_except($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_except(a, b)"), ans3) + + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 5L)) + checkAnswer(df4.select(array_except($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_except(a, b)"), ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("c", "f")) + checkAnswer(df5.select(array_except($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_except(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + intercept[AnalysisException] { + df6.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_except(a, b)") + } + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df7.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df7.selectExpr("array_except(a, b)") + } + val df8 = Seq((Array("a"), null)).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_except(a, b)") + } + val df9 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df9.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df9.selectExpr("array_except(a, b)") + } + + val df10 = Seq( + (Array[Integer](1, 2), Array[Integer](2)), + (Array[Integer](1, 2), Array[Integer](1, null)), + (Array[Integer](1, null, 3), Array[Integer](1, 2)), + (Array[Integer](1, null), Array[Integer](2, null)) + ).toDF("a", "b") + val result10 = df10.select(array_except($"a", $"b")) + val expectedType10 = ArrayType(IntegerType, containsNull = true) + assert(result10.first.schema(0).dataType === expectedType10) } - test("SPARK-21281 fails if functions have no argument") { - val df = Seq(1).toDF("a") + test("array_intersect functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(2, 4)) + checkAnswer(df1.select(array_intersect($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_intersect(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(2, null, 4)) + checkAnswer(df2.select(array_intersect($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_intersect(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(2L, 4L)) + checkAnswer(df3.select(array_intersect($"a", $"b")), ans3) + 
checkAnswer(df3.selectExpr("array_intersect(a, b)"), ans3) + + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(2L, null, 4L)) + checkAnswer(df4.select(array_intersect($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_intersect(a, b)"), ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", "a", null, "g"))).toDF("a", "b") + val ans5 = Row(Seq(null, "a")) + checkAnswer(df5.select(array_intersect($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_intersect(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + assert(intercept[AnalysisException] { + df6.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df6.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) - val funcsMustHaveAtLeastOneArg = - ("coalesce", (df: DataFrame) => df.select(coalesce())) :: - ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: - ("named_struct", (df: DataFrame) => df.select(struct())) :: - ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: - ("hash", (df: DataFrame) => df.select(hash())) :: - ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil - funcsMustHaveAtLeastOneArg.foreach { case (name, func) => - val errMsg = intercept[AnalysisException] { func(df) }.getMessage - assert(errMsg.contains(s"input to function $name requires at least one argument")) - } + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df7.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df7.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) - val funcsMustHaveAtLeastTwoArgs = - ("greatest", (df: DataFrame) => df.select(greatest())) :: + val df8 = Seq((null, Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df8.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df8.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + } + + test("transform function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, 8)), + Row(Seq(6, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, 10)), + Row(Seq(5, 9, 11, 10, 6)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("transform function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, 
null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("transform function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("transform(s, x -> concat(x, x))"), + Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(s, (x, i) -> concat(x, i))"), + Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("transform function - special cases") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("arg") + + def testSpecialCases(): Unit = { + checkAnswer(df.selectExpr("transform(arg, arg -> arg)"), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", null, "c", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, arg)"), + Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, x -> concat(arg, array(x)))"), + Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testSpecialCases() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testSpecialCases() + } + + test("transform function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("transform(s, (x, y, z) -> x + y + z)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("transform(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("transform(a, x -> x)") + } + assert(ex3.getMessage.contains("cannot resolve '`a`'")) + } + + test("map_filter") { + val dfInts = Seq( + Map(1 -> 10, 2 -> 20, 3 -> 30), + Map(1 -> -1, 2 -> -2, 3 -> -3), + Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m") + + checkAnswer(dfInts.selectExpr( + "map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"), + Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3)))) + + val dfComplex = Seq( + Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))), + Map(1 -> null, 
2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m") + + checkAnswer(dfComplex.selectExpr( + "map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"), + Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2))))) + + // Invalid use cases + val df = Seq( + (Map(1 -> "a"), 1), + (Map.empty[Int, String], 2), + (null, 3) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_filter(s, x -> x)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_filter(i, (k, v) -> k > v)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_filter(a, (k, v) -> k > v)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + + test("filter function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("filter function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("filter function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("filter(s, x -> x is not null)"), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("filter function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("filter(s, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("filter(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = 
intercept[AnalysisException] { + df.selectExpr("filter(s, x -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("filter(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + + test("exists function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 9, 7), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("exists function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, null, 9, 7, null), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("exists function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("exists(s, x -> x is null)"), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("exists function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("exists(s, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("exists(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("exists(s, x -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("exists(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + + test("aggregate function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + 
testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("aggregate function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) + checkAnswer( + df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"), + Seq( + Row(250), + Row(0), + Row(0), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("aggregate function - array for non-primitive type") { + val df = Seq( + (Seq("c", "a", "b"), "a"), + (Seq("b", null, "c", null), "b"), + (Seq.empty, "c"), + (null, "d") + ).toDF("ss", "s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"), + Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null))) + checkAnswer( + df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc , ''))"), + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("aggregate function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, '', x -> x)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, '', (acc, x) -> x, (acc, x) -> x)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("aggregate(i, 0, (acc, x) -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, 0, (acc, x) -> x)") + } + assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("aggregate(a, 0, (acc, x) -> x)") + } + assert(ex5.getMessage.contains("cannot resolve '`a`'")) + } + + test("map_zip_with function - map of primitive types") { + val df = Seq( + (Map(8 -> 6L, 3 -> 5L, 6 -> 2L), Map[Integer, Integer]((6, 4), (8, 2), (3, 2))), + (Map(10 -> 6L, 8 -> 3L), Map[Integer, Integer]((8, 4), (4, null))), + (Map.empty[Int, Long], Map[Integer, Integer]((5, 1))), + (Map(5 -> 1L), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) + } + + test("map_zip_with function - map of non-primitive types") { + val df = Seq( + (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")), + (Map("b" -> 
"a", "c" -> "d"), Map("c" -> "a", "b" -> null, "d" -> "k")), + (Map("a" -> "d"), Map.empty[String, String]), + (Map("a" -> "d"), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) + } + + test("map_zip_with function - invalid") { + val df = Seq( + (Map(1 -> 2), Map(1 -> "a"), Map("a" -> "b"), Map(Map(1 -> 2) -> 2), 1) + ).toDF("mii", "mis", "mss", "mmi", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mii, mis, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") + } + assert(ex2.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with compatible key types")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") + } + assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") + } + assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") + } + assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) + } + + test("transform keys function - primitive data types") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) + ).toDF("j") + + val dfExample3 = Seq( + Map[Int, Boolean](25 -> true, 26 -> false) + ).toDF("x") + + val dfExample4 = Seq( + Map[Array[Int], Boolean](Array(1, 2) -> false) + ).toDF("y") + + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), + Seq(Row(Map(true -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), + Seq(Row(Map(false -> false)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + // Test with cached relation, the Project will 
be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform keys function - Invalid lambda functions and exceptions") { + val dfExample1 = Seq( + Map[String, String]("a" -> null) + ).toDF("i") + + val dfExample2 = Seq( + Seq(1, 2, 3, 4) + ).toDF("j") + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains( + "The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[RuntimeException] { + dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() + } + assert(ex3.getMessage.contains("Cannot use null as map key!")) + + val ex4 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") + } + assert(ex4.getMessage.contains( + "data type mismatch: argument 1 requires map type")) + } + + test("transform values function - test primitive data types") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Boolean, String](false -> "abc", true -> "def") + ).toDF("x") + + val dfExample3 = Seq( + Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3) + ).toDF("y") + + val dfExample4 = Seq( + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) + ).toDF("z") + + val dfExample5 = Seq( + Map[Int, Array[Int]](1 -> Array(1, 2)) + ).toDF("c") + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), + Seq(Row(Map(false -> "false", true -> "def")))) + + checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), + Seq(Row(Map(false -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"), + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + + checkAnswer(dfExample3.selectExpr( + "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), + Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) + + checkAnswer( + dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + + checkAnswer( + dfExample4.selectExpr( + "transform_values(" + + "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + + "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + + checkAnswer( + dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + + checkAnswer( + dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), + Seq(Row(Map(1 -> 3)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + dfExample5.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform values function - test empty") { + val dfExample1 = Seq( + Map.empty[Integer, Integer] + ).toDF("i") + + val dfExample2 = Seq( + Map.empty[BigInt, String] + 
).toDF("j") + + def testEmpty(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> NULL)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> v)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 0)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 'value')"), + Seq(Row(Map.empty[Integer, String]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> true)"), + Seq(Row(Map.empty[Integer, Boolean]))) + + checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), + Seq(Row(Map.empty[BigInt, BigInt]))) + } + + testEmpty() + dfExample1.cache() + dfExample2.cache() + testEmpty() + } + + test("transform values function - test null values") { + val dfExample1 = Seq( + Map[Int, Integer](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4) + ).toDF("a") + + val dfExample2 = Seq( + Map[Int, String](1 -> "a", 2 -> "b", 3 -> null) + ).toDF("b") + + def testNullValue(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(a, (k, v) -> null)"), + Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), + Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) + } + + testNullValue() + dfExample1.cache() + dfExample2.cache() + testNullValue() + } + + test("transform values function - test invalid functions") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[String, String]("a" -> "b") + ).toDF("j") + + val dfExample3 = Seq( + Seq(1, 2, 3, 4) + ).toDF("x") + + def testInvalidLambdaFunctions(): Unit = { + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_values(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[AnalysisException] { + dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)") + } + assert(ex3.getMessage.contains( + "data type mismatch: argument 1 requires map type")) + } + + testInvalidLambdaFunctions() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + testInvalidLambdaFunctions() + } + + test("arrays zip_with function - for primitive types") { + val df1 = Seq[(Seq[Integer], Seq[Integer])]( + (Seq(9001, 9002, 9003), Seq(4, 5, 6)), + (Seq(1, 2), Seq(3, 4)), + (Seq.empty, Seq.empty), + (null, null) + ).toDF("val1", "val2") + val df2 = Seq[(Seq[Integer], Seq[Long])]( + (Seq(1, null, 3), Seq(1L, 2L)), + (Seq(1, 2, 3), Seq(4L, 11L)) + ).toDF("val1", "val2") + val expectedValue1 = Seq( + Row(Seq(9005, 9007, 9009)), + Row(Seq(4, 6)), + Row(Seq.empty), + Row(null)) + checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) + val expectedValue2 = Seq( + Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))), + Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3)))) + checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2) + } + + test("arrays zip_with function - for non-primitive types") { 
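+    // zip_with pads the shorter array with nulls, so the lambda still runs once per index of the
+    // longer array; e.g. zip_with(array('a'), array('x', 'y'), (x, y) -> (y, x)) gives
+    // [('x', 'a'), ('y', null)], which is the behaviour exercised below.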
+ val df = Seq( + (Seq("a"), Seq("x", "y", "z")), + (Seq("a", null), Seq("x", "y")), + (Seq.empty[String], Seq.empty[String]), + (Seq("a", "b", "c"), null) + ).toDF("val1", "val2") + val expectedValue1 = Seq( + Row(Seq(Row("x", "a"), Row("y", null), Row("z", null))), + Row(Seq(Row("x", "a"), Row("y", null))), + Row(Seq.empty), + Row(null)) + checkAnswer(df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue1) + } + + test("arrays zip_with function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), Seq("x", "y", "z"), 1), + (Seq("b", null, "c", null), Seq("x"), 2), + (Seq.empty, Seq("x", "z"), 3), + (null, Seq("x", "z"), 4) + ).toDF("a1", "a2", "i") + val ex1 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a2, x -> x)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + val ex2 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a2, (acc, x) -> x, (acc, x) -> x)") + } + assert(ex2.getMessage.contains("Invalid number of arguments for function zip_with")) + val ex3 = intercept[AnalysisException] { + df.selectExpr("zip_with(i, a2, (acc, x) -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex4 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a, (acc, x) -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { + import DataFrameFunctionsSuite.CodegenFallbackExpr + for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { + val c = if (codegenFallback) { + Column(CodegenFallbackExpr(v.expr)) + } else { + v + } + withSQLConf( + (SQLConf.CODEGEN_FALLBACK.key, codegenFallback.toString), + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { + val df = spark.range(0, 4, 1, 4).withColumn("c", c) + val rows = df.collect() + val rowsAfterCoalesce = df.coalesce(2).collect() + assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + + val df1 = spark.range(0, 2, 1, 2).withColumn("c", c) + val rows1 = df1.collect() + val df2 = spark.range(2, 4, 1, 2).withColumn("c", c) + val rows2 = df2.collect() + val rowsAfterUnion = df1.union(df2).collect() + assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + } + } + } + + test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + + "coalesce or union") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) + } + + test("SPARK-21281 use string types by default if array and map have no argument") { + val ds = spark.range(1) + var expectedSchema = new StructType() + .add("x", ArrayType(StringType, containsNull = false), nullable = false) + assert(ds.select(array().as("x")).schema == expectedSchema) + expectedSchema = new StructType() + .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) + assert(ds.select(map().as("x")).schema == expectedSchema) + } + + test("SPARK-21281 fails if functions have no argument") { + val df = Seq(1).toDF("a") + + val funcsMustHaveAtLeastOneArg = + ("coalesce", (df: DataFrame) => df.select(coalesce())) :: + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: + 
("named_struct", (df: DataFrame) => df.select(struct())) :: + ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: + ("hash", (df: DataFrame) => df.select(hash())) :: + ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil + funcsMustHaveAtLeastOneArg.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least one argument")) + } + + val funcsMustHaveAtLeastTwoArgs = + ("greatest", (df: DataFrame) => df.select(greatest())) :: ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) :: ("least", (df: DataFrame) => df.select(least())) :: ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil @@ -762,6 +2694,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg.contains(s"input to function $name requires at least two arguments")) } } + + test("SPARK-24734: Fix containsNull of Concat for array type") { + val df = Seq((Seq(1), Seq[Integer](null), Seq("a", "b"))).toDF("k1", "k2", "v") + val ex = intercept[RuntimeException] { + df.select(map_from_arrays(concat($"k1", $"k2"), $"v")).show() + } + assert(ex.getMessage.contains("Cannot use null as map key")) + } } object DataFrameFunctionsSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala index 0dd5bdcba2e4c..7ef8b542c79a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -59,4 +59,14 @@ class DataFrameHintSuite extends AnalysisTest with SharedSQLContext { ) ) } + + test("coalesce and repartition hint") { + check( + df.hint("COALESCE", 10), + UnresolvedHint("COALESCE", Seq(10), df.logicalPlan)) + + check( + df.hint("REPARTITION", 100), + UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0d9eeabb397a1..e6b30f9956daf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -196,7 +196,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") // outer -> left - val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" === 3) + val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" >= 3) assert(outerJoin2Left.queryExecution.optimizedPlan.collect { case j @ Join(_, _, LeftOuter, _) => j }.size === 1) checkAnswer( @@ -204,7 +204,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(3, 4, "3", null, null, null) :: Nil) // outer -> right - val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" === 5) + val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" >= 3) assert(outerJoin2Right.queryExecution.optimizedPlan.collect { case j @ Join(_, _, RightOuter, _) => j }.size === 1) checkAnswer( @@ -221,7 +221,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", 1, 3, "1") :: Nil) // right -> inner - val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" === 1) + val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", 
"right").where($"a.int" > 0) assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { case j @ Join(_, _, Inner, _) => j }.size === 1) checkAnswer( @@ -229,7 +229,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", 1, 3, "1") :: Nil) // left -> inner - val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" === 3) + val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" > 0) assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { case j @ Join(_, _, Inner, _) => j }.size === 1) checkAnswer( @@ -287,4 +287,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan } } + + test("SPARK-24385: Resolve ambiguity in self-joins with EqualNullSafe") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(2) + // this throws an exception before the fix + df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 6ca9ee57e8f49..b972b9ef93e5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("pivot courses") { + val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings")), - Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java")) + .agg(sum($"earnings")), + expected) } test("pivot year") { + val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), - Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)), + expected) } test("pivot courses with multiple aggregations") { + val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: + Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy($"year") .pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings"), avg($"earnings")), - Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: - Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year") + .pivot($"course", Seq("dotNET", "Java")) + .agg(sum($"earnings"), avg($"earnings")), + expected) } test("pivot year with string values (cast)") { @@ -67,17 +79,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { test("pivot courses with no values") { // Note Java comes before dotNet in sorted order + val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil checkAnswer( courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), - Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + expected) } test("pivot year with no values") { + val expected = Row("dotNET", 15000.0, 48000.0) :: 
Row("Java", 20000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), - Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + expected) } test("pivot max values enforced") { @@ -181,10 +199,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { } test("pivot with datatype not supported by PivotFirst") { + val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil checkAnswer( complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")), - Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil - ) + expected) + checkAnswer( + complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)), + expected) } test("pivot with datatype not supported by PivotFirst 2") { @@ -246,4 +267,45 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone)) } } + + test("SPARK-24722: pivoting nested columns") { + val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase)) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("SPARK-24722: references to multiple columns in the pivot column") { + val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET")) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("SPARK-24722: pivoting by a constant") { + val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil + val df1 = trainingSales + .groupBy($"sales.year") + .pivot(lit(123), Seq(123)) + .agg(sum($"sales.earnings")) + + checkAnswer(df1, expected) + } + + test("SPARK-24722: aggregate as the pivot column") { + val exception = intercept[AnalysisException] { + trainingSales + .groupBy($"sales.year") + .pivot(min($"training"), Seq("Experts")) + .agg(sum($"sales.earnings")) + } + + assert(exception.getMessage.contains("aggregate functions are not allowed")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index a0fd74088ce8b..b0b46640ff317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql -import java.util.concurrent.{CountDownLatch, TimeUnit} - import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -154,53 +152,35 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - // Save and restore the value because SparkContext is shared - val savedInterruptOnCancel = sparkContext - .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) - - try { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") - - for (codegen <- Seq(true, false)) { - // This countdown 
latch used to make sure with all the stages cancelStage called in listener - val latch = new CountDownLatch(2) - - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - sparkContext.cancelStage(taskStart.stageId) - latch.countDown() - } - } + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) + } + } - sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - val ex = intercept[SparkException] { - sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => - x.synchronized { - x.wait() - } - x - }.toDF("id").agg(sum("id")).collect() - } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") - } + sparkContext.addSparkListener(listener) + for (codegen <- Seq(true, false)) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + val ex = intercept[SparkException] { + spark.range(0, 100000000000L, 1, 1) + .toDF("id").agg(sum("id")).collect() } - latch.await(20, TimeUnit.SECONDS) - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } - sparkContext.removeSparkListener(listener) } - } finally { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, - savedInterruptOnCancel) + // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages + sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60e84e6ee7504..6f5c73074313c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -27,6 +27,7 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} @@ -36,7 +37,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} -import org.apache.spark.sql.test.SQLTestData.TestData2 +import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -629,6 +630,74 @@ 
class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("except all") { + checkAnswer( + lowerCaseData.exceptAll(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) + checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.exceptAll(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.exceptAll(nullInts), + Nil) + + // check that duplicate values are preserved + checkAnswer( + allNulls.exceptAll(allNulls.filter("0 = 1")), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + checkAnswer( + allNulls.exceptAll(allNulls.limit(2)), + Row(null) :: Row(null) :: Nil) + + // check that duplicates are retained. + val df = spark.sparkContext.parallelize( + NullStrings(1, "id1") :: + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") + + checkAnswer( + df.exceptAll(df.filter("0 = 1")), + Row(1, "id1") :: + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").exceptAll(allNulls), + Nil) + + } + + test("exceptAll - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.exceptAll(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.exceptAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.exceptAll(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.exceptAll(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("intersect") { checkAnswer( lowerCaseData.intersect(lowerCaseData), @@ -681,6 +750,60 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // Duplicate nulls are preserved. 
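+    // allNulls holds four null rows; INTERSECT ALL keeps every matching occurrence, whereas a
+    // plain intersect would collapse them into a single Row(null). For example,
+    //   Seq(1, 1, 2).toDF("v").intersectAll(Seq(1, 2, 2).toDF("v"))
+    // returns Row(1) :: Row(2) :: Nil, i.e. the minimum of the two multiplicities per value.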
+ checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) @@ -1044,6 +1167,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select($"*").show(1000) } + test("getRows: truncate = [0, 20]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111111111111111111111")) + assert(df.getRows(10, 0) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111111...")) + assert(df.getRows(10, 20) === expectedAnswerForTrue) + } + + test("getRows: truncate = [3, 17]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111")) + assert(df.getRows(10, 3) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111...")) + assert(df.getRows(10, 17) === expectedAnswerForTrue) + } + + test("getRows: numRows = 0") { + val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1")) + assert(testData.select($"*").getRows(0, 20) === expectedAnswer) + } + + test("getRows: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[1, 2, 3]", "[1, 2, 3]"), + Seq("[2, 3, 4]", "[2, 3, 4]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + + test("getRows: binary") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[31 32]", "[41 42 43 2E]"), + Seq("[33 34]", "[31 32 33 34 36]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + test("showString: truncate = [0, 20]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() @@ -2261,8 +2443,123 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + test("SPARK-24165: CaseWhen/If - nullability of nested types") { + val rows = new java.util.ArrayList[Row]() + rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x"))) + rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null))) + val schema = 
StructType(Seq( + StructField("cond", BooleanType, true), + StructField("s", StructType(Seq( + StructField("val1", StringType, true), + StructField("val2", IntegerType, false) + )), false), + StructField("a", ArrayType(StringType, true)), + StructField("m", MapType(IntegerType, StringType, true)) + )) + + val sourceDF = spark.createDataFrame(rows, schema) + + def structWhenDF: DataFrame = sourceDF + .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") + .select('res.getField("val1")) + def arrayWhenDF: DataFrame = sourceDF + .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res") + .select('res.getItem(0)) + def mapWhenDF: DataFrame = sourceDF + .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") + .select('res.getItem(0)) + + def structIfDF: DataFrame = sourceDF + .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") + .select('res.getField("val1")) + def arrayIfDF: DataFrame = sourceDF + .select(expr("if(cond, array('a', 'b'), a)") as "res") + .select('res.getItem(0)) + def mapIfDF: DataFrame = sourceDF + .select(expr("if(cond, map(0, 'a'), m)") as "res") + .select('res.getItem(0)) + + def checkResult(): Unit = { + checkAnswer(structWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(arrayWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(mapWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(structIfDF, Seq(Row("a"), Row(null))) + checkAnswer(arrayIfDF, Seq(Row("a"), Row(null))) + checkAnswer(mapIfDF, Seq(Row("a"), Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + checkResult() + // Test with cached relation, the Project will be evaluated with codegen + sourceDF.cache() + checkResult() + } + test("Uuid expressions should produce same results at retries in the same DataFrame") { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) } + + test("SPARK-24313: access map with binary keys") { + val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) + checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) + } + + test("SPARK-24781: Using a reference from Dataset in Filter/Sort") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + val filter1 = df.select(df("name")).filter(df("id") === 0) + val filter2 = df.select(col("name")).filter(col("id") === 0) + checkAnswer(filter1, filter2.collect()) + + val sort1 = df.select(df("name")).orderBy(df("id")) + val sort2 = df.select(col("name")).orderBy(col("id")) + checkAnswer(sort1, sort2.collect()) + } + + test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + + val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) + val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) + checkAnswer(aggPlusSort1, aggPlusSort2.collect()) + + val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) + val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) + checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) + } + } + + test("SPARK-25159: json schema inference should only trigger one job") { + withTempPath { path => + // This test is to prove that the `JsonInferSchema` does not use `RDD#toLocalIterator` which + // triggers one Spark job per RDD partition. 
+ Seq(1 -> "a", 2 -> "b").toDF("i", "p") + // The data set has 2 partitions, so Spark will write at least 2 json files. + // Use a non-splittable compression (gzip), to make sure the json scan RDD has at least 2 + // partitions. + .write.partitionBy("p").option("compression", "gzip").json(path.getCanonicalPath) + + var numJobs = 0 + sparkContext.addSparkListener(new SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + numJobs += 1 + } + }) + + val df = spark.read.json(path.getCanonicalPath) + assert(df.columns === Array("i", "p")) + assert(numJobs == 1) + } + } + + test("SPARK-23034 show rdd names in RDD scan nodes") { + val rddWithName = spark.sparkContext.parallelize(Row(1, "abc") :: Nil).setName("testRdd") + val df2 = spark.createDataFrame(rddWithName, StructType.fromDDL("c0 int, c1 string")) + val output2 = new java.io.ByteArrayOutputStream() + Console.withOut(output2) { + df2.explain(extended = false) + } + assert(output2.toString.contains("Scan ExistingRDD testRdd")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 6fe356877c268..2953425b1db49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -43,6 +43,22 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } + test("SPARK-21590: tumbling window using negative start time") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a"), + ("2016-03-27 19:39:25", 2, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 2) + ) + ) + } + test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -72,6 +88,20 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B Seq(Row(1), Row(1), Row(1))) } + test("SPARK-21590: tumbling window groupBy statement with negative startTime") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"), $"id") + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1))) + } + test("tumbling window with multi-column projection") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -309,4 +339,19 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } } + + test("SPARK-21590: time window in SQL with three expressions including negative start time") { + withTempTable { table => + checkAnswer( + spark.sql( + s"""select window(time, "10 seconds", 10000000, "-5 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1), + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4), + Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2) + ) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 0ee9b0edc02b2..2a0b2b85e10a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -402,4 +402,18 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: Row(10, 6000) :: Nil) } + + test("SPARK-24033: Analysis Failure of OffsetWindowFunction") { + val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") + val res = + Row(1, 1, null) :: Row (1, 2, 1) :: Row(1, 3, 2) :: Row(2, 1, null) :: Row(2, 2, 1) :: Nil + checkAnswer( + ds.withColumn("m", + lead("i", -1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + checkAnswer( + ds.withColumn("m", + lag("i", 1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 3ea398aad7375..97a843978f0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import java.sql.{Date, Timestamp} - -import scala.collection.mutable +import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} @@ -27,7 +25,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval /** * Window function testing for DataFrame API. 
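// SPARK-24033 sketch: lead with a negative offset behaves like lag with the positive
// offset, which is what the frames test above asserts (assumes spark.implicits._,
// functions._ and org.apache.spark.sql.expressions.Window):
val offsets = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i")
val w = Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1)
offsets.withColumn("m", lead("i", -1).over(w))   // same rows as lag("i", 1).over(w)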
@@ -624,4 +621,41 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { + def checkAnalysisError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses")) + } + + checkAnalysisError(testData2.select('a).where(rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError(testData2.where('b === 2 && rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(avg('b).as("avgb")) + .where('a > 'avgb && rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where(rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where('sumb === 5 && rank().over(Window.orderBy('a)) === 1)) + + checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError( + sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql( + s"""SELECT a, MAX(b) + |FROM testData2 + |GROUP BY a + |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0e7eaa9e88d57..538ea3c66c40e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -148,6 +148,41 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { } +case class OptionBooleanData(name: String, isGood: Option[Boolean]) + +case class OptionBooleanAggregator(colName: String) + extends Aggregator[Row, Option[Boolean], Option[Boolean]] { + + override def zero: Option[Boolean] = None + + override def reduce(buffer: Option[Boolean], row: Row): Option[Boolean] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[Boolean] + } else { + Some(row.getBoolean(index)) + } + merge(buffer, value) + } + + override def merge(b1: Option[Boolean], b2: Option[Boolean]): Option[Boolean] = { + if ((b1.isDefined && b1.get) || (b2.isDefined && b2.get)) { + Some(true) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[Boolean]): Option[Boolean] = reduction + + override def bufferEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + override def outputEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + + def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -333,4 +368,29 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { df.groupBy($"i").agg(VeryComplexResultAgg.toColumn), Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil) } + + test("SPARK-24569: Aggregator with output type Option[Boolean] creates column of type Row") { + val df = Seq( 
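// Window expressions are rejected inside WHERE/HAVING (SPARK-24575, asserted above);
// the usual rewrite is to materialise the window value in a projection first and then
// filter on it. A sketch against testData2 (columns a, b; assumes spark.implicits._,
// functions._ and org.apache.spark.sql.expressions.Window):
testData2
  .withColumn("rnk", rank().over(Window.orderBy('b)))
  .where('rnk === 1)
  .select('a)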
+ OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val group = df + .groupBy("name") + .agg(OptionBooleanAggregator("isGood").toColumn.alias("isGood")) + assert(df.schema == group.schema) + checkAnswer(group, Row("bob", true) :: Nil) + checkDataset(group.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } + + test("SPARK-24569: groupByKey with Aggregator of output type Option[Boolean]") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val grouped = df.groupByKey((r: Row) => r.getString(0)) + .agg(OptionBooleanAggregator("isGood").toColumn).toDF("name", "isGood") + + assert(grouped.schema == df.schema) + checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index e0561ee2797a5..44177e36caa01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -17,14 +17,28 @@ package org.apache.spark.sql +import org.scalatest.concurrent.TimeLimits +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel -class DatasetCacheSuite extends QueryTest with SharedSQLContext { +class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ + /** + * Asserts that a cached [[Dataset]] will be built using the given number of other cached results. 
+ */ + private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = { + val plan = df.queryExecution.withCachedData + assert(plan.isInstanceOf[InMemoryRelation]) + val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon) + } + test("get storage level") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -96,4 +110,111 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { agged.unpersist() assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } + + test("persist and then withColumn") { + val df = Seq(("test", 1)).toDF("s", "i") + val df2 = df.withColumn("newColumn", lit(1)) + + df.cache() + assertCached(df) + assertCached(df2) + + df.count() + assertCached(df2) + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("cache UDF result correctly") { + val expensiveUDF = udf({x: Int => Thread.sleep(5000); x}) + val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + assertCached(df2) + + // udf has been evaluated during caching, and thus should not be re-evaluated here + failAfter(3 seconds) { + df2.collect() + } + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("SPARK-24613 Cache with UDF could not be matched with subsequent dependent caches") { + val udf1 = udf({x: Int => x + 1}) + val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + df2.cache() + + assertCacheDependency(df2) + } + + test("SPARK-24596 Non-cascading Cache Invalidation") { + val df = Seq(("a", 1), ("b", 2)).toDF("s", "i") + val df2 = df.filter('i > 1) + val df3 = df.filter('i < 2) + + df2.cache() + df.cache() + df.count() + df3.cache() + + df.unpersist() + + // df un-cached; df2 and df3's cache plan re-compiled + assert(df.storageLevel == StorageLevel.NONE) + assertCacheDependency(df2, 0) + assertCacheDependency(df3, 0) + } + + test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data reuse") { + val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x }) + val df = spark.range(0, 5).toDF("a") + val df1 = df.withColumn("b", expensiveUDF($"a")) + val df2 = df1.groupBy('a).agg(sum('b)) + val df3 = df.agg(sum('a)) + + df1.cache() + df2.cache() + df2.collect() + df3.cache() + + assertCacheDependency(df2) + + df1.unpersist(blocking = true) + + // df1 un-cached; df2's cache plan re-compiled + assert(df1.storageLevel == StorageLevel.NONE) + assertCacheDependency(df1.groupBy('a).agg(sum('b)), 0) + + val df4 = df1.groupBy('a).agg(sum('b)).agg(sum("sum(b)")) + assertCached(df4) + // reuse loaded cache + failAfter(3 seconds) { + checkDataset(df4, Row(10)) + } + + val df5 = df.agg(sum('a)).filter($"sum(a)" > 1) + assertCached(df5) + // first time use, load cache + checkDataset(df5, Row(10)) + } + + test("SPARK-24850 InMemoryRelation string representation does not include cached plan") { + val df = Seq(1).toDF("a").cache() + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + df.explain(false) + } + assert(outputStream.toString.replaceAll("#\\d+", "#x").contains( + "InMemoryRelation [a#x], StorageLevel(disk, memory, deserialized, 1 replicas)" + )) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
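// SPARK-24596 sketch (mirrors the tests above): unpersisting a plan no longer drops the
// caches built on top of it; their cache plans are recompiled so they stop depending on
// the removed cache (assumes spark.implicits._):
val base = Seq(("a", 1), ("b", 2)).toDF("s", "i")
val derived = base.filter('i > 1)
derived.cache(); base.cache()
base.count()                 // materialise base's cache
base.unpersist()             // base is uncached ...
assert(base.storageLevel == org.apache.spark.storage.StorageLevel.NONE)
// ... while `derived` stays cached, now with a plan that no longer reads base's cache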
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e0f4d2ba685e1..6069f28d185e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1296,7 +1296,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { new java.sql.Timestamp(100000)) } - test("SPARK-19896: cannot have circular references in in case class") { + test("SPARK-19896: cannot have circular references in case class") { val errMsg1 = intercept[UnsupportedOperationException] { Seq(CircularReferenceClassA(null)).toDS } @@ -1425,6 +1425,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23627: provide isEmpty in DataSet") { + val ds1 = spark.emptyDataset[Int] + val ds2 = Seq(1, 2, 3).toDS() + + assert(ds1.isEmpty == true) + assert(ds2.isEmpty == false) + } + test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] @@ -1458,6 +1466,48 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + + test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + val a = Seq(Some(1)).toDS + val b = Seq(Some(1.2)).toDS + val expected = Seq((Some(1), Some(1.2))).toDS + val joined = a.joinWith(b, lit(true)) + assert(joined.schema == expected.schema) + checkDataset(joined, expected.collect: _*) + } + } + + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { + val encoder = Encoders.tuple(newStringEncoder, + Encoders.tuple(newStringEncoder, newStringEncoder)) + + val data = Seq(("a", ("1", "2")), ("b", ("3", "4"))) + val rdd = sparkContext.parallelize(data) + + val ds1 = spark.createDataset(rdd) + val ds2 = spark.createDataset(rdd)(encoder) + assert(ds1.schema == ds2.schema) + checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) + } + + test("SPARK-24571: filtering of string values by char literal") { + val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") + checkAnswer(df.where('city === 'X'), Seq(Row("X"))) + checkAnswer( + df.where($"city".contains(new java.lang.Character('A'))), + Seq(Row("Amsterdam"))) + } + + test("SPARK-23034 show rdd names in RDD scan nodes") { + val rddWithName = spark.sparkContext.parallelize(SingleData(1) :: Nil).setName("testRdd") + val df = spark.createDataFrame(rddWithName) + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + df.explain(extended = false) + } + assert(output.toString.contains("Scan testRdd")) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 6bbf38516cdf6..3af80b36ec42c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,6 +23,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval @@ -327,6 +328,13 @@ class DateFunctionsSuite 
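// SPARK-23627 sketch: Dataset gains an isEmpty accessor, as checked in the test above
// (assumes spark.implicits._):
spark.emptyDataset[Int].isEmpty   // true
Seq(1, 2, 3).toDS().isEmpty       // false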
extends QueryTest with SharedSQLContext { val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + checkAnswer(df.selectExpr("months_between(t, s, true)"), Seq(Row(0.5), Row(-0.5))) + Seq(true, false).foreach { roundOff => + checkAnswer(df.select(months_between(col("t"), col("d"), roundOff)), + Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.withColumn("r", lit(false)).selectExpr("months_between(t, s, r)"), + Seq(Row(0.5), Row(-0.5))) + } } test("function last_day") { @@ -655,7 +663,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) } - test("from_utc_timestamp") { + test("from_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -672,7 +680,24 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-24 17:00:00")))) } - test("to_utc_timestamp") { + test("from_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "CET"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "PST") + ).toDF("a", "b", "c") + checkAnswer( + df.select(from_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + checkAnswer( + df.select(from_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + } + + test("to_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -689,4 +714,28 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("to_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "PST"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "CET") + ).toDF("a", "b", "c") + checkAnswer( + df.select(to_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + checkAnswer( + df.select(to_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + } + + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { + withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { + checkAnswer( + sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')"), + Row(Timestamp.valueOf("2000-10-09 18:00:00"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 06303099f5310..4aa6afd69620b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql -import java.io.FileNotFoundException +import java.io.{File, FileNotFoundException} +import java.util.Locale import 
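// Sketch of the column time-zone overloads exercised above: from_utc_timestamp and
// to_utc_timestamp now also accept the zone as a Column, so every row can carry its own
// zone (assumes spark.implicits._ and functions._; values mirror the test data):
val tsDf = Seq(
  (java.sql.Timestamp.valueOf("2015-07-24 00:00:00"), "CET"),
  (java.sql.Timestamp.valueOf("2015-07-25 00:00:00"), "PST")
).toDF("ts", "zone")
tsDf.select(from_utc_timestamp($"ts", $"zone"))
// => 2015-07-24 02:00:00 (CET) and 2015-07-24 17:00:00 (PST)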
org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -202,4 +204,301 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + // Text file format only supports string type + test("SPARK-24691 error handling for unsupported types - text") { + withTempDir { dir => + // write path + val textDir = new File(dir, "text").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq(1).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + Seq(1.2).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + Seq(true).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + + msg = intercept[AnalysisException] { + Seq(1).toDF("a").selectExpr("struct(a)").write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((Map("Tesla" -> 3))).toDF("cars").write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((Array("Tesla", "Chevy", "Ford"))).toDF("brands") + .write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support array data type")) + + // read path + Seq("aaa").toDF.write.mode("overwrite").text(textDir) + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", DoubleType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", BooleanType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + } + } + + // Unsupported data types of csv, json, orc, and parquet are as follows; + // csv -> R/W: Null, Array, Map, Struct + // json -> R/W: Interval + // orc -> R/W: Interval, W: Null + // parquet -> R/W: Interval, Null + test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a struct") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((1, 
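// The text data source accepts exactly one string column (see the assertions above);
// other types have to be cast or serialised first. A sketch with a hypothetical output
// path (assumes spark.implicits._):
Seq(1, 2, 3).toDF("id")
  .select($"id".cast("string"))
  .write.text("/tmp/ints-as-text")        // ok: a single string column
// Seq(1, 2, 3).toDF("id").write.text(...) would fail with
//   "Text data source does not support int data type"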
Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a map") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a array") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[AnalysisException] { + Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support array data type.")) + } + } + + test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + } + + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + } + } + } + + test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("orc").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + 
assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + // We expect the types below should be passed for backward-compatibility + + // Null type + var schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + + // UDT having null data + schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + } + + Seq("parquet", "csv").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + } + } + } + + test(s"SPARK-25132: case-insensitive field resolution when reading from Parquet") { + withTempDir { dir => + val format = "parquet" + val tableDir = dir.getCanonicalPath + s"/$format" + val tableName = s"spark_25132_${format}" + withTable(tableName) { + val end = 5 + val data = spark.range(end).selectExpr("id as A", "id * 2 as b", "id * 3 as B") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + data.write.format(format).mode("overwrite").save(tableDir) + } + sql(s"CREATE TABLE $tableName (a LONG, b LONG) USING $format LOCATION '$tableDir'") + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswer(sql(s"select a from $tableName"), data.select("A")) + checkAnswer(sql(s"select A from $tableName"), data.select("A")) + + // RuntimeException is triggered at executor side, which is then wrapped as + // SparkException at driver side + val e1 = intercept[SparkException] { + sql(s"select b from $tableName").collect() + } + assert( + e1.getCause.isInstanceOf[RuntimeException] && + e1.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + val e2 = intercept[SparkException] { + sql(s"select B from $tableName").collect() + } + assert( + 
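// SPARK-25132 in short (the surrounding test spells it out): with
// spark.sql.caseSensitive=false a requested Parquet column resolves only if exactly one
// physical field matches ignoring case; ambiguous matches fail at scan time. A sketch
// using the DDL-string schema variant and a hypothetical path:
spark.conf.set("spark.sql.caseSensitive", "true")
spark.range(5).selectExpr("id as A", "id * 2 as b", "id * 3 as B")
  .write.mode("overwrite").parquet("/tmp/spark_25132_demo")
spark.conf.set("spark.sql.caseSensitive", "false")
val demo = spark.read.schema("a LONG, b LONG").parquet("/tmp/spark_25132_demo")
demo.select("a").collect()     // ok: only "A" matches
// demo.select("b").collect() fails:
//   Found duplicate field(s) "b": [b, B] in case-insensitive mode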
e2.getCause.isInstanceOf[RuntimeException] && + e2.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer(sql(s"select a from $tableName"), (0 until end).map(_ => Row(null))) + checkAnswer(sql(s"select b from $tableName"), data.select("b")) + } + } + } + } +} + +object TestingUDT { + + @SQLUserDefinedType(udt = classOf[IntervalUDT]) + class IntervalData extends Serializable + + class IntervalUDT extends UserDefinedType[IntervalData] { + + override def sqlType: DataType = CalendarIntervalType + override def serialize(obj: IntervalData): Any = + throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): IntervalData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[IntervalData] = classOf[IntervalData] + } + + @SQLUserDefinedType(udt = classOf[NullUDT]) + private[sql] class NullData extends Serializable + + private[sql] class NullUDT extends UserDefinedType[NullData] { + + override def sqlType: DataType = NullType + override def serialize(obj: NullData): Any = throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): NullData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[NullData] = classOf[NullData] + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala deleted file mode 100644 index c6dd7dadc9d93..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.io.File - -import scala.util.{Random, Try} - -import org.apache.spark.SparkConf -import org.apache.spark.sql.functions.monotonically_increasing_id -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - - -/** - * Benchmark to measure read performance with Filter pushdown. 
- */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - conf.set("orc.compression", "snappy") - conf.set("spark.sql.parquet.compression.codec", "snappy") - - private val spark = SparkSession.builder() - .master("local[1]") - .appName("FilterPushdownBenchmark") - .config(conf) - .getOrCreate() - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { - import spark.implicits._ - val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") - val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) - .withColumn("id", monotonically_increasing_id()) - - val dirORC = dir.getCanonicalPath + "/orc" - val dirParquet = dir.getCanonicalPath + "/parquet" - - df.write.mode("overwrite").orc(dirORC) - df.write.mode("overwrite").parquet(dirParquet) - - spark.read.orc(dirORC).createOrReplaceTempView("orcTable") - spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") - } - - def filterPushDownBenchmark( - values: Int, - title: String, - whereExpr: String, - selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() - } - } - } - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() - } - } - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz - - Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X - Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X - Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X - Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X - - Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X - Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X - Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X - Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X - - Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - 
----------------------------------------------------------------------------------------------- - Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X - Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X - Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X - Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X - - Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X - Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X - Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X - - Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X - Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X - Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X - Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X - - Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X - Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X - Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X - - Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X - Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X - Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X - Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X - - Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X - Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X - Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X - Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X - - Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X - Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X - Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X - Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X - - Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X - Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X - Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X - Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X - - Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X - Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X - Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X - Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 
767.9 1.2X - - Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X - Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X - Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X - Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X - */ - benchmark.run() - } - - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - val mid = numRows / 2 - - withTempPath { dir => - withTempTable("orcTable", "patquetTable") { - prepareTable(dir, numRows, width) - - Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => - val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - Seq( - s"id = $mid", - s"id <=> $mid", - s"$mid <= id AND id <= $mid", - s"${mid - 1} < id AND id < ${mid + 1}" - ).foreach { whereExpr => - val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") - - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% rows (id < ${numRows * percent / 100})", - s"id < ${numRows * percent / 100}", - selectExpr - ) - } - - Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => - filterPushDownBenchmark( - numRows, - s"Select all rows ($whereExpr)", - whereExpr, - selectExpr) - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 109fcf90a3ec9..8280a3ce39845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} @@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator { override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val iteratorClass = classOf[Iterator[_]].getName - ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + ev.copy(code = + code"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 771e1186e63ab..44767dfc92497 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -239,7 +239,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) } - assert(e.getMessage.contains("Detected cartesian product for INNER join " + + assert(e.getMessage.contains("Detected implicit cartesian product for INNER join " + "between logical plans")) } } @@ -611,7 
+611,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val e = intercept[Exception] { checkAnswer(sql(query), Nil); } - assert(e.getMessage.contains("Detected cartesian product")) + assert(e.getMessage.contains("Detected implicit cartesian product")) } cartesianQueries.foreach(checkCartesianDetection) @@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) } } + + test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(0, 100, 1, 2) + val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2")) + val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id") + checkAnswer(res, Row(0, 0, 0)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 00d2acc4a1d8a..f321ab86e9b7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -133,15 +133,11 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) :: Nil) } - test("from_json invalid schema") { + test("from_json - json doesn't conform to the array type") { val df = Seq("""{"a" 1}""").toDS() val schema = ArrayType(StringType) - val message = intercept[AnalysisException] { - df.select(from_json($"value", schema)) - }.getMessage - assert(message.contains( - "Input schema array must be a struct or an array of structs.")) + checkAnswer(df.select(from_json($"value", schema)), Seq(Row(null))) } test("from_json array support") { @@ -311,7 +307,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg1 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 1)") } - assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string")) val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } @@ -326,4 +322,151 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg4.getMessage.startsWith( "A type of keys and values in map() must be string, but got")) } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() + val schema = + """ + |{ + | "type" : "map", + | "keyType" : "string", + | "valueType" : "integer", + | "valueContainsNull" : true + |} + """.stripMargin + val out = in.select(from_json($"value", schema, Map[String, String]())) + + assert(out.columns.head == "entries") + checkAnswer(out, Row(Map("a" -> 1, "b" -> 2, "c" -> 3))) + } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, new StructType().add("b", IntegerType), true) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Row(1)))) + } + + test("SPARK-24027: from_json - map>") { + val in = 
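// The analysis error checked above now reads "Detected implicit cartesian product"; an
// intentional cartesian product can still be written explicitly, or allowed through the
// cross-join flag (crossJoin and spark.sql.crossJoin.enabled are the standard API/config):
val smallA = spark.range(3)
val smallB = spark.range(2)
smallA.crossJoin(smallB)                                    // explicit cross join: no error
// spark.conf.set("spark.sql.crossJoin.enabled", "true")    // or permit implicit ones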
Seq("""{"a": {"b": 1}}""").toDS() + val schema = "map>" + val out = in.select(from_json($"value", schema, Map.empty[String, String])) + + checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) + } + + test("SPARK-24027: roundtrip - from_json -> to_json - map") { + val json = """{"a":1,"b":2,"c":3}""" + val schema = MapType(StringType, IntegerType, true) + val out = Seq(json).toDS().select(to_json(from_json($"value", schema))) + + checkAnswer(out, Row(json)) + } + + test("SPARK-24027: roundtrip - to_json -> from_json - map") { + val in = Seq(Map("a" -> 1)).toDF() + val schema = MapType(StringType, IntegerType, true) + val out = in.select(from_json(to_json($"value"), schema)) + + checkAnswer(out, in) + } + + test("SPARK-24027: from_json - wrong map") { + val in = Seq("""{"a" 1}""").toDS() + val schema = MapType(StringType, IntegerType) + val out = in.select(from_json($"value", schema, Map[String, String]())) + + checkAnswer(out, Row(null)) + } + + test("SPARK-24027: from_json of a map with unsupported key type") { + val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType) + + checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + } + + test("SPARK-24709: infers schemas of json strings and pass them to from_json") { + val in = Seq("""{"a": [1, 2, 3]}""").toDS() + val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed") + val expected = StructType(StructField( + "parsed", + StructType(StructField( + "a", + ArrayType(LongType, true), true) :: Nil), + true) :: Nil) + + assert(out.schema == expected) + } + + test("from_json - array of primitive types") { + val df = Seq("[1, 2, 3]").toDF("a") + val schema = new ArrayType(IntegerType, false) + + checkAnswer(df.select(from_json($"a", schema)), Seq(Row(Array(1, 2, 3)))) + } + + test("from_json - array of primitive types - malformed row") { + val df = Seq("[1, 2 3]").toDF("a") + val schema = new ArrayType(IntegerType, false) + + checkAnswer(df.select(from_json($"a", schema)), Seq(Row(null))) + } + + test("from_json - array of arrays") { + val jsonDF = Seq("[[1], [2, 3], [4, 5, 6]]").toDF("a") + val schema = new ArrayType(ArrayType(IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("select json[0][0], json[1][1], json[2][2] from jsonTable"), + Seq(Row(1, 3, 6))) + } + + test("from_json - array of arrays - malformed row") { + val jsonDF = Seq("[[1], [2, 3], 4, 5, 6]]").toDF("a") + val schema = new ArrayType(ArrayType(IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer(sql("select json[0] from jsonTable"), Seq(Row(null))) + } + + test("from_json - array of structs") { + val jsonDF = Seq("""[{"a":1}, {"a":2}, {"a":3}]""").toDF("a") + val schema = new ArrayType(new StructType().add("a", IntegerType), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("select json[0], json[1], json[2] from jsonTable"), + Seq(Row(Row(1), Row(2), Row(3)))) + } + + test("from_json - array of structs - malformed row") { + val jsonDF = Seq("""[{"a":1}, {"a:2}, {"a":3}]""").toDF("a") + val schema = new ArrayType(new StructType().add("a", IntegerType), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + 
checkAnswer(sql("select json[0], json[1]from jsonTable"), Seq(Row(null, null))) + } + + test("from_json - array of maps") { + val jsonDF = Seq("""[{"a":1}, {"b":2}]""").toDF("a") + val schema = new ArrayType(MapType(StringType, IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("""select json[0], json[1] from jsonTable"""), + Seq(Row(Map("a" -> 1), Map("b" -> 2)))) + } + + test("from_json - array of maps - malformed row") { + val jsonDF = Seq("""[{"a":1} "b":2}]""").toDF("a") + val schema = new ArrayType(MapType(StringType, IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer(sql("""select json[0] from jsonTable"""), Seq(Row(null))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala index d66a6902b0510..6b90f20a94fa4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -30,21 +30,20 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self override def beforeAll() { super.beforeAll() InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } override def afterEach() { try { - resetSparkContext() + LocalSparkSession.stop(spark) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + spark = null } finally { super.afterEach() } } - - def resetSparkContext(): Unit = { - LocalSparkSession.stop(spark) - spark = null - } - } object LocalSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index cfe2e9f2dbc44..cdcea09ad9758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -54,4 +54,18 @@ class RuntimeConfigSuite extends SparkFunSuite { conf.get("k1") } } + + test("SPARK-24761: is a config parameter modifiable") { + val conf = newConf() + + // SQL configs + assert(!conf.isModifiable("spark.sql.sources.schemaStringLengthThreshold")) + assert(conf.isModifiable("spark.sql.streaming.checkpointLocation")) + // Core configs + assert(!conf.isModifiable("spark.task.cpus")) + assert(!conf.isModifiable("spark.executor.cores")) + // Invalid config parameters + assert(!conf.isModifiable("")) + assert(!conf.isModifiable("invalid config parameter")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 640affc10ee58..01dc28d70184e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.datasources.FilePartition import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import 
org.apache.spark.sql.internal.SQLConf @@ -523,6 +524,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } + test("limit for skew dataframe") { + // Create a skew dataframe. + val df = testData.repartition(100).union(testData).limit(50) + // Because `rdd` of dataframe will add a `DeserializeToObject` on top of `GlobalLimit`, + // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test if `GlobalLimit` + // work on skew partitions. + assert(df.rdd.count() == 50L) + } + test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1689,22 +1699,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) - e = intercept[AnalysisException] { - sql(s"select id from `com.databricks.spark.avro`.`file_path`") - } - assert(e.message.contains("Failed to find data source: com.databricks.spark.avro.")) - - // data source type is case insensitive - e = intercept[AnalysisException] { - sql(s"select id from Avro.`file_path`") - } - assert(e.message.contains("Failed to find data source: avro.")) - - e = intercept[AnalysisException] { - sql(s"select id from avro.`file_path`") - } - assert(e.message.contains("Failed to find data source: avro.")) - e = intercept[AnalysisException] { sql(s"select id from `org.apache.spark.sql.sources.HadoopFsRelationProvider`.`file_path`") } @@ -1950,7 +1944,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") + val df = sql("SELECT a, b from testData2 order by a, b limit 1") checkAnswer(df, Row(1, 1)) checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) @@ -2704,7 +2698,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val m = intercept[AnalysisException] { sql("SELECT * FROM t, S WHERE c = C") }.message - assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) + assert( + m.contains("cannot resolve '(default.t.`c` = default.S.`C`)' due to data type mismatch")) } } } @@ -2792,4 +2787,77 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + withTable("fact_stats", "dim_stats") { + val factData = Seq((1, 1, 99, 1), (2, 2, 99, 2), (3, 1, 99, 3), (4, 2, 99, 4)) + val storeData = Seq((1, "BW", "DE"), (2, "AZ", "US")) + spark.udf.register("filterND", udf((value: Int) => value > 2).asNondeterministic) + factData.toDF("date_id", "store_id", "product_id", "units_sold") + .write.mode("overwrite").partitionBy("store_id").format("parquet").saveAsTable("fact_stats") + storeData.toDF("store_id", "state_province", "country") + .write.mode("overwrite").format("parquet").saveAsTable("dim_stats") + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.store_id FROM + |(SELECT date_id, product_id, store_id + | FROM fact_stats WHERE filterND(date_id)) AS f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) + checkAnswer(df, Seq(Row(3, 99, 1))) + } + } + + + test("SPARK-24940: coalesce and repartition hint") { + withTempView("nums1") { + val numPartitionsSrc = 10 + spark.range(0, 100, 1, numPartitionsSrc).createOrReplaceTempView("nums1") + assert(spark.table("nums1").rdd.getNumPartitions 
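// SPARK-24940 sketch (the test continues below): the number of output partitions can be
// requested through SQL hints, and COALESCE can only reduce it (assumes a temp view
// nums1 with 10 source partitions, as in the test):
spark.sql("SELECT /*+ REPARTITION(20) */ * FROM nums1").rdd.getNumPartitions  // 20
spark.sql("SELECT /*+ COALESCE(20) */ * FROM nums1").rdd.getNumPartitions     // still 10
spark.sql("SELECT /*+ COALESCE(2) */ * FROM nums1").rdd.getNumPartitions      // 2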
== numPartitionsSrc) + + withTable("nums") { + sql("CREATE TABLE nums (id INT) USING parquet") + + Seq(5, 20, 2).foreach { numPartitions => + sql( + s""" + |INSERT OVERWRITE TABLE nums + |SELECT /*+ REPARTITION($numPartitions) */ * + |FROM nums1 + """.stripMargin) + assert(spark.table("nums").inputFiles.length == numPartitions) + + sql( + s""" + |INSERT OVERWRITE TABLE nums + |SELECT /*+ COALESCE($numPartitions) */ * + |FROM nums1 + """.stripMargin) + // Coalesce can not increase the number of partitions + assert(spark.table("nums").inputFiles.length == Seq(numPartitions, numPartitionsSrc).min) + } + } + } + } + + test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") { + withView("spark_25084") { + val count = 1000 + val df = spark.range(count) + val columns = (0 until 400).map{ i => s"id as id$i" } + val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",") + df.selectExpr(columns : _*).createTempView("spark_25084") + assert( + spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count) + } + } + + test("SPARK-25144 'distinct' causes memory leak") { + val ds = List(Foo(Some("bar"))).toDS + val result = ds.flatMap(_.bar).distinct + result.rdd.isEmpty + } } + +case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index beac9699585d5..826408c7161e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -54,6 +54,7 @@ import org.apache.spark.sql.types.StructType * The format for input files is simple: * 1. A list of SQL queries separated by semicolon. * 2. Lines starting with -- are treated as comments and ignored. + * 3. Lines starting with --SET are used to run the file with the following set of configs. * * For example: * {{{ @@ -138,18 +139,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) + val (comments, code) = input.split("\n").partition(_.startsWith("--")) + val configSets = { + val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) + val configs = configLines.map(_.split(",").map { confAndValue => + val (conf, value) = confAndValue.span(_ != '=') + conf.trim -> value.substring(1).trim + }) + // When we are regenerating the golden files we don't need to run all the configs as they + // all need to return the same result + if (regenerateGoldenFiles && configs.nonEmpty) { + configs.take(1) + } else { + configs + } + } // List of SQL queries to run - val queries: Seq[String] = { - val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n") - // note: this is not a robust way to split queries using semicolon, but works for now. - cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + // note: this is not a robust way to split queries using semicolon, but works for now. 
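+ // The lookbehind in the split pattern below only matches a ';' whose preceding character is not a backslash, so escaped semicolons stay inside a single query.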
+ val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + + if (configSets.isEmpty) { + runQueries(queries, testCase.resultFile, None) + } else { + configSets.foreach { configSet => + try { + runQueries(queries, testCase.resultFile, Some(configSet)) + } catch { + case e: Throwable => + val configs = configSet.map { + case (k, v) => s"$k=$v" + } + logError(s"Error using configs: ${configs.mkString(",")}") + throw e + } + } } + } + private def runQueries( + queries: Seq[String], + resultFileName: String, + configSet: Option[Seq[(String, String)]]): Unit = { // Create a local SparkSession to have stronger isolation between different test cases. // This does not isolate catalog changes. val localSparkSession = spark.newSession() loadTestData(localSparkSession) + if (configSet.isDefined) { + // Execute the list of set operation in order to add the desired configs + val setOperations = configSet.get.map { case (key, value) => s"set $key=$value" } + logInfo(s"Setting configs: ${setOperations.mkString(", ")}") + setOperations.foreach(localSparkSession.sql) + } // Run the SQL queries preparing them for comparison. val outputs: Seq[QueryOutput] = queries.map { sql => val (schema, output) = getNormalizedResult(localSparkSession, sql) @@ -167,7 +208,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile) + val resultFile = new File(resultFileName) val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) @@ -177,7 +218,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Read back the golden file. 
val expectedOutputs: Seq[QueryOutput] = { - val goldenOutput = fileToString(new File(testCase.resultFile)) + val goldenOutput = fileToString(new File(resultFileName)) val segments = goldenOutput.split("-- !query.+\n") // each query has 3 segments, plus the header diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 14a565863d66c..cb562d65b6147 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), + assert(sizes.head === BigInt(128), s"expected exact size 96 for table 'test', got: ${sizes.head}") } } @@ -204,6 +204,24 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("SPARK-25028: column stats collection for null partitioning columns") { + val table = "analyze_partition_with_null" + withTempDir { dir => + withTable(table) { + sql(s""" + |CREATE TABLE $table (value string, name string) + |USING PARQUET + |PARTITIONED BY (name) + |LOCATION '${dir.toURI}'""".stripMargin) + val df = Seq(("a", null), ("b", null)).toDF("value", "name") + df.write.mode("overwrite").insertInto(table) + sql(s"ANALYZE TABLE $table PARTITION (name) COMPUTE STATISTICS") + val partitions = spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + assert(partitions.head.stats.get.rowCount.get == 2) + } + } + } + test("number format in statistics") { val numbers = Seq( BigInt(0) -> (("0.0 B", "0")), @@ -382,4 +400,32 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } } + + test("Simple queries must be working, if CBO is turned on") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + withTable("TBL1", "TBL") { + import org.apache.spark.sql.functions._ + val df = spark.range(1000L).select('id, + 'id * 2 as "FLD1", + 'id * 12 as "FLD2", + lit("aaa") + 'id as "fld3") + df.write + .mode(SaveMode.Overwrite) + .bucketBy(10, "id", "FLD1", "FLD2") + .sortBy("id", "FLD1", "FLD2") + .saveAsTable("TBL") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS ") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3") + val df2 = spark.sql( + """ + |SELECT t1.id, t1.fld1, t1.fld2, t1.fld3 + |FROM tbl t1 + |JOIN tbl t2 on t1.id=t2.id + |WHERE t1.fld3 IN (-123.23,321.23) + """.stripMargin) + df2.createTempView("TBL2") + sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 31e8b0e8dede0..cbffed994bb4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.Join +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { @@ -955,4 +958,314 @@ class 
SubquerySuite extends QueryTest with SharedSQLContext { // before the fix this would throw AnalysisException spark.range(10).where("(id,id) in (select id, null from range(3))").count } + + test("SPARK-24085 scalar subquery in partitioning expression") { + withTable("parquet_part") { + Seq("1" -> "a", "2" -> "a", "3" -> "b", "4" -> "b") + .toDF("id_value", "id_type") + .write + .mode(SaveMode.Overwrite) + .partitionBy("id_type") + .format("parquet") + .saveAsTable("parquet_part") + checkAnswer( + sql("SELECT * FROM parquet_part WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } + + private def getNumSortsInQuery(query: String): Int = { + val plan = sql(query).queryExecution.optimizedPlan + getNumSorts(plan) + getSubqueryExpressions(plan).map{s => getNumSorts(s.plan)}.sum + } + + private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = { + val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression] + plan transformAllExpressions { + case s: SubqueryExpression => + subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s) + s + } + subqueryExpressions + } + + private def getNumSorts(plan: LogicalPlan): Int = { + plan.collect { case s: Sort => s }.size + } + + test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple order by + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Nested order bys + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT * + | FROM t2 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + + // nested IN + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM t2 + | WHERE c1 IN (SELECT c1 + | FROM t3 + | WHERE c1 = 1 + | ORDER BY c3) + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Complex subplan and multiple sorts + val query4 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Join in subplan + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT t2.c1 FROM t2, t3 + | WHERE t2.c1 = t3.c1 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 0) + + val query6 = + """ + |SELECT c1 + |FROM t1 + |WHERE (c1, c2) IN (SELECT c1, max(c2) + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | GROUP BY c1 + | HAVING max(c2) > 0 + | ORDER BY c1) + """.stripMargin + // The rule to remove redundant sorts is not able to remove the inner sort under + // an Aggregate operator. We only remove the top level sort. 
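+ // Hence exactly one Sort (the inner one under the Aggregate) is expected to remain for query6.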
+ assert(getNumSortsInQuery(query6) == 1) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query7 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query7) == 1) + + // Sort below a set operations (intersect, union) + val query8 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (( + | SELECT c1 FROM t2 + | ORDER BY c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | ORDER BY c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query8) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple order by exists correlated + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Nested order by and correlated. + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM (SELECT * + | FROM t2 + | WHERE t2.c1 = t1.c1 + | ORDER BY t2.c2) t2 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + // nested EXISTS + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM t2 + | WHERE EXISTS (SELECT c1 + | FROM t3 + | WHERE t3.c1 = t2.c1 + | ORDER BY c3) + | AND t2.c1 = t1.c1 + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query4 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 1) + + // Sort below a set operations (intersect, union) + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 1 + | ORDER BY t2.c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 2 + | ORDER BY t2.c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query5) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Two scalar subqueries in OR + val query1 = + """ + |SELECT * FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | ORDER BY max(t2.c1)) + |OR c2 = (SELECT min(t3.c2) + | FROM t3 + | WHERE t3.c1 = 1 + | ORDER BY min(t3.c2)) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // scalar subquery - groupby and having + val query2 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + // nested scalar subquery + val query3 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Scalar 
subquery in projection + val query4 = + """ + |SELECT (SELECT min(c1) from t1 group by c1 order by c1) + |FROM t1 + |WHERE t1.c1 = 1 + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Limit on top of sort prevents it from being pruned. + val query5 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1) + | LIMIT 1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index bc95b4696190d..817224d1c28ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -147,7 +147,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING, |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING, |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING, - |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2)) + |`s_gmt_offset` DECIMAL(5,2), `s_tax_percentage` DECIMAL(5,2)) |USING parquet """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 21afdc7e2a33f..30dca9497ddde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -19,11 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand} +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.functions.{lit, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types.{DataTypes, DoubleType} +import org.apache.spark.sql.util.QueryExecutionListener + private case class FunctionResult(f1: String, f2: String) @@ -324,4 +329,68 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)")) } } + + test("cached Data should be used in the write path") { + withTable("t") { + withTempPath { path => + var numTotalCachedHit = 0 + val listener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + qe.withCachedData match { + case c: CreateDataSourceTableAsSelectCommand + if c.query.isInstanceOf[InMemoryRelation] => + numTotalCachedHit += 1 + case i: InsertIntoHadoopFsRelationCommand + if i.query.isInstanceOf[InMemoryRelation] => + numTotalCachedHit += 1 + case _ => + } + } + } + spark.listenerManager.register(listener) + + val udf1 = udf({ (x: Int, y: Int) => x + y }) + val df = spark.range(0, 3).toDF("a") + .withColumn("b", udf1($"a", lit(10))) 
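+ // Cache the DataFrame first; each write below should then see an InMemoryRelation as the query of its write command, which the listener above counts as a cache hit.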
+ df.cache() + df.write.saveAsTable("t") + assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable") + df.write.insertInto("t") + assert(numTotalCachedHit == 2, "expected to be cached in insertInto") + df.write.save(path.getCanonicalPath) + assert(numTotalCachedHit == 3, "expected to be cached in save for native") + } + } + } + + test("SPARK-24891 Fix HandleNullInputsForUDF rule") { + val udf1 = udf({(x: Int, y: Int) => x + y}) + val df = spark.range(0, 3).toDF("a") + .withColumn("b", udf1($"a", udf1($"a", lit(10)))) + .withColumn("c", udf1($"a", lit(null))) + val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed + + comparePlans(df.logicalPlan, plan) + checkAnswer( + df, + Seq( + Row(0, 10, null), + Row(1, 12, null), + Row(2, 14, null))) + } + + test("SPARK-24891 Fix HandleNullInputsForUDF rule - with table") { + withTable("x") { + Seq((1, "2"), (2, "4")).toDF("a", "b").write.format("json").saveAsTable("x") + sql("insert into table x values(3, null)") + sql("insert into table x values(null, '4')") + spark.udf.register("f", (a: Int, b: String) => a + b) + val df = spark.sql("SELECT f(a, b) FROM x") + val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed + comparePlans(df.logicalPlan, plan) + checkAnswer(df, Seq(Row("12"), Row("24"), Row("3null"), Row(null))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 737eeb0af586e..41de731d41f82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -50,7 +50,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId) + new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1)) } val estimatedPartitionStartIndices = coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -58,7 +58,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices - 1 Exchange") { - val coordinator = new ExchangeCoordinator(1, 100L) + val coordinator = new ExchangeCoordinator(100L) { // All bytes per partition are 0. 
@@ -105,7 +105,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices - 2 Exchanges") { - val coordinator = new ExchangeCoordinator(2, 100L) + val coordinator = new ExchangeCoordinator(100L) { // If there are multiple values of the number of pre-shuffle partitions, @@ -114,8 +114,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1), - new MapOutputStatistics(1, bytesByPartitionId2)) + new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)), + new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0))) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } @@ -199,7 +199,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices and enforce minimal number of reducers") { - val coordinator = new ExchangeCoordinator(2, 100L, Some(2)) + val coordinator = new ExchangeCoordinator(100L, Some(2)) { // The minimal number of post-shuffle partitions is not enforced because @@ -480,4 +480,17 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { withSparkSession(test, 6144, minNumPostShufflePartitions) } } + + test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { + val test = { spark: SparkSession => + spark.sql("SET spark.sql.exchange.reuse=true") + val df = spark.range(1).selectExpr("id AS key", "id AS value") + val resultDf = df.join(df, "key").join(df, "key") + val sparkPlan = resultDf.queryExecution.executedPlan + assert(sparkPlan.collect { case p: ReusedExchangeExec => p }.length == 1) + assert(sparkPlan.collect { case p @ ShuffleExchangeExec(_, _, Some(c)) => p }.length == 3) + checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil) + } + withSparkSession(test, 4, None) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 40915a102bab0..3db89ecfad9fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext { testPartialAggregationPlan(query) } + 
test("mixed aggregates with same distinct columns") { + def assertNoExpand(plan: SparkPlan): Unit = { + assert(plan.collect { case e: ExpandExec => e }.isEmpty) + } + + withTempView("v") { + Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v") + // one distinct column + val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i") + assertNoExpand(query1.queryExecution.executedPlan) + + // 2 distinct columns + val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i") + assertNoExpand(query2.queryExecution.executedPlan) + + // 2 distinct columns with different order + val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i") + assertNoExpand(query3.queryExecution.executedPlan) + } + } + test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType]): Unit = { withTempView("testLimit") { @@ -194,7 +215,19 @@ class PlannerSuite extends SharedSQLContext { test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] - assert(planned.child.isInstanceOf[CollectLimitExec]) + assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) + } + + test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { + val query0 = testData.select('value).orderBy('key).limit(100) + val planned0 = query0.queryExecution.executedPlan + assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + + val query1 = testData.select('value).orderBy('key).limit(2000) + val planned1 = query1.queryExecution.executedPlan + assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) + } } test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { @@ -229,7 +262,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } { @@ -244,7 +277,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } } @@ -591,7 +624,7 @@ class PlannerSuite extends SharedSQLContext { dataType = LongType, nullable = false ) (exprId = exprId, - qualifier = Some("col1_qualifier") + qualifier = Seq("col1_qualifier") ) val attribute2 = @@ -621,6 +654,115 @@ class PlannerSuite extends SharedSQLContext { requiredOrdering = Seq(orderingA, orderingB), shouldHaveSort = true) } + + test("SPARK-24242: RangeExec should have correct output ordering and partitioning") { + val df = spark.range(10) + val rangeExec = df.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + val range = df.queryExecution.optimizedPlan.collect { + case r: Range => r + } + assert(rangeExec.head.outputOrdering == range.head.outputOrdering) + assert(rangeExec.head.outputPartitioning == + RangePartitioning(rangeExec.head.outputOrdering, df.rdd.getNumPartitions)) + + val rangeInOnePartition = spark.range(1, 10, 1, 1) + val rangeExecInOnePartition = rangeInOnePartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInOnePartition.head.outputPartitioning == SinglePartition) + + val rangeInZeroPartition = 
spark.range(-10, -9, -20, 1) + val rangeExecInZeroPartition = rangeInZeroPartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) + } + + test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") { + val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + val smjExec = SortMergeJoinExec( + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + assert(leftKeys == Seq(exprA, exprA)) + assert(rightKeys == Seq(exprB, exprC)) + case _ => fail() + } + } + + test("SPARK-24500: create union with stream of children") { + val df = Union(Stream( + Range(1, 1, 1, 1), + Range(1, 2, 1, 1))) + df.queryExecution.executedPlan.execute() + } + + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + + "and InMemoryTableScanExec") { + def checkOutputPartitioningRewrite( + plans: Seq[SparkPlan], + expectedPartitioningClass: Class[_]): Unit = { + assert(plans.size == 1) + val plan = plans.head + val partitioning = plan.outputPartitioning + assert(partitioning.getClass == expectedPartitioningClass) + val partitionedAttrs = partitioning.asInstanceOf[Expression].references + assert(partitionedAttrs.subsetOf(plan.outputSet)) + } + + def checkReusedExchangeOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val reusedExchange = df.queryExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + } + checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass) + } + + def checkInMemoryTableScanOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val inMemoryScan = df.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass) + } + + // ReusedExchange is HashPartitioning + val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning]) + + // ReusedExchange is RangePartitioning + val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning]) + + // InMemoryTableScan is HashPartitioning + Seq(1 -> "a").toDF("i", "j").repartition($"i").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning]) + + // InMemoryTableScan is RangePartitioning + spark.range(1, 100, 1, 10).toDF().persist() + checkInMemoryTableScanOutputPartitioningRewrite( + spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning]) + + // InMemoryTableScan is PartitioningCollection + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"), + 
classOf[PartitioningCollection]) + } + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala index aaf51b5b90111..d088e24e53bfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.IntegerType /** * Tests for the sameResult function for [[SparkPlan]]s. @@ -58,4 +61,16 @@ class SameResultSuite extends QueryTest with SharedSQLContext { val df4 = spark.range(10).agg(sumDistinct($"id")) assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan)) } + + test("Canonicalized result is case-insensitive") { + val a = AttributeReference("A", IntegerType)() + val b = AttributeReference("B", IntegerType)() + val planUppercase = Project(Seq(a), LocalRelation(a, b)) + + val c = AttributeReference("a", IntegerType)() + val d = AttributeReference("b", IntegerType)() + val planLowercase = Project(Seq(c), LocalRelation(c, d)) + + assert(planUppercase.sameResult(planLowercase)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala new file mode 100644 index 0000000000000..05f7e3ce83880 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types._ + +class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { + private val ignoredField = StructField("col1", StringType, nullable = false) + + // The test schema as a tree string, i.e. 
`schema.treeString` + // root + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field1: integer (nullable = true) + // | |-- field6: struct (nullable = true) + // | | |-- subfield1: string (nullable = false) + // | | |-- subfield2: string (nullable = true) + // | |-- field7: struct (nullable = true) + // | | |-- subfield1: struct (nullable = true) + // | | | |-- subsubfield1: integer (nullable = true) + // | | | |-- subsubfield2: integer (nullable = true) + // | |-- field9: map (nullable = true) + // | | |-- key: string + // | | |-- value: integer (valueContainsNull = false) + private val nestedComplex = StructType(ignoredField :: + StructField("col2", StructType( + StructField("field1", IntegerType) :: + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: + StructField("subfield2", StringType) :: Nil)) :: + StructField("field7", StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: + StructField("field9", + MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) :: Nil) + + test("SelectedField should not match an attribute reference") { + val testRelation = LocalRelation(nestedComplex.toAttributes) + assertResult(None)(unapplySelect("col1", testRelation)) + assertResult(None)(unapplySelect("col1 as foo", testRelation)) + assertResult(None)(unapplySelect("col2", testRelation)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field2: array (nullable = true) + // | | |-- element: integer (containsNull = false) + // | |-- field3: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: integer (nullable = true) + // | | | |-- subfield3: array (nullable = true) + // | | | | |-- element: integer (containsNull = true) + private val structOfArray = StructType(ignoredField :: + StructField("col2", StructType( + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) + :: Nil)) + :: Nil) + + testSelect(structOfArray, "col2.field2", "col2.field2[0] as foo") { + StructField("col2", StructType( + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field9", "col2.field9['foo'] as foo") { + StructField("col2", StructType( + StructField("field9", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) + } + + testSelect(structOfArray, "col2.field3.subfield3", "col2.field3[0].subfield3 as foo", + "col2.field3.subfield3[0] as foo", "col2.field3[0].subfield3[0] as foo") { + StructField("col2", StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structOfArray, "col2.field3.subfield1") { + StructField("col2", StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: Nil)), nullable = false) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field4: map (nullable = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | 
| |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: array (nullable = true) + // | | | | |-- element: integer (containsNull = false) + // | |-- field8: map (nullable = true) + // | | |-- key: string + // | | |-- value: array (valueContainsNull = false) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: array (nullable = true) + // | | | | | |-- element: integer (containsNull = false) + private val structWithMap = StructType( + ignoredField :: + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil + ), valueContainsNull = false)) :: + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil) + ), valueContainsNull = false)) :: Nil + )) :: Nil + ) + + testSelect(structWithMap, "col2.field4['foo'].subfield1 as foo") { + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: Nil), valueContainsNull = false)) :: Nil)) + } + + testSelect(structWithMap, + "col2.field4['foo'].subfield2 as foo", "col2.field4['foo'].subfield2[0] as foo") { + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) + :: Nil), valueContainsNull = false)) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field5: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: struct (nullable = false) + // | | | | |-- subsubfield1: integer (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + // | | | |-- subfield2: struct (nullable = true) + // | | | | |-- subsubfield1: struct (nullable = true) + // | | | | | |-- subsubsubfield1: string (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + private val structWithArray = StructType( + ignoredField :: + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) :: + StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)), nullable = false) :: Nil) + ) :: Nil + ) + + testSelect(structWithArray, "col2.field5.subfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithArray, "col2.field5.subfield1.subsubfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: Nil), nullable = false) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithArray, "col2.field5.subfield2.subsubfield1.subsubsubfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + 
StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: Nil)) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithMap, "col2.field8['foo'][0].subfield1 as foo") { + StructField("col2", StructType( + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: Nil)), valueContainsNull = false)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field1") { + StructField("col2", StructType( + StructField("field1", IntegerType) :: Nil)) + } + + testSelect(nestedComplex, "col2.field6") { + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: + StructField("subfield2", StringType) :: Nil)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field6.subfield1") { + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: Nil)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field7.subfield1") { + StructField("col2", StructType( + StructField("field7", StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col3: array (nullable = false) + // | |-- element: struct (containsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + private val arrayWithStructAndMap = StructType(Array( + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), containsNull = false), nullable = false) + )) + + testSelect(arrayWithStructAndMap, "col3.field1.subfield1") { + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) + :: Nil), containsNull = false), nullable = false) + } + + testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") { + StructField("col3", ArrayType(StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), containsNull = false), nullable = false) + } + + // |-- col1: string (nullable = false) + // |-- col4: map (nullable = false) + // | |-- key: string + // | |-- value: struct (valueContainsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + private val col4 = StructType(Array(ignoredField, + StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), valueContainsNull = false), nullable = false) + )) + + testSelect(col4, "col4['foo'].field1.subfield1 as foo") { + 
StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) + :: Nil), valueContainsNull = false), nullable = false) + } + + testSelect(col4, "col4['foo'].field2['bar'] as foo") { + StructField("col4", MapType(StringType, StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), valueContainsNull = false), nullable = false) + } + + // |-- col1: string (nullable = false) + // |-- col5: array (nullable = true) + // | |-- element: map (containsNull = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + private val arrayOfStruct = StructType(Array(ignoredField, + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + )) + + testSelect(arrayOfStruct, "col5[0]['foo'].field1.subfield1 as foo") { + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + } + + // |-- col1: string (nullable = false) + // |-- col6: map (nullable = true) + // | |-- key: string + // | |-- value: array (valueContainsNull = true) + // | | |-- element: struct (containsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + private val mapOfArray = StructType(Array(ignoredField, + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))))) + + testSelect(mapOfArray, "col6['foo'][0].field1.subfield1 as foo") { + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false))) + } + + // An array with a struct with a different fields + // |-- col1: string (nullable = false) + // |-- col7: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: integer (nullable = false) + // | | |-- field2: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | |-- field3: array (nullable = true) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = false) + private val arrayWithMultipleFields = StructType(Array(ignoredField, + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: + StructField("field2", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))))) + + testSelect(arrayWithMultipleFields, + "col7.field1", "col7[0].field1 as foo", "col7.field1[0] as foo") { + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: Nil))) + } + + testSelect(arrayWithMultipleFields, "col7.field2.subfield1") { + StructField("col7", ArrayType(StructType( + StructField("field2", 
StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil))) + } + + testSelect(arrayWithMultipleFields, "col7.field3.subfield1") { + StructField("col7", ArrayType(StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))) + } + + // Array with a nested int array + // |-- col1: string (nullable = false) + // |-- col8: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: array (nullable = false) + // | | | |-- element: integer (containsNull = false) + private val arrayOfArray = StructType(Array(ignoredField, + StructField("col8", + ArrayType(StructType(Array(StructField("field1", + ArrayType(IntegerType, containsNull = false), nullable = false)))) + ))) + + testSelect(arrayOfArray, "col8.field1", + "col8[0].field1 as foo", + "col8.field1[0] as foo", + "col8[0].field1[0] as foo") { + StructField("col8", ArrayType(StructType( + StructField("field1", ArrayType(IntegerType, containsNull = false), nullable = false) + :: Nil))) + } + + def assertResult(expected: StructField)(actual: StructField)(selectExpr: String): Unit = { + try { + super.assertResult(expected)(actual) + } catch { + case ex: TestFailedException => + // Print some helpful diagnostics in the case of failure + alert("Expected SELECT \"" + selectExpr + "\" to select the schema\n" + + indent(StructType(expected :: Nil).treeString) + + indent("but it actually selected\n") + + indent(StructType(actual :: Nil).treeString) + + indent("Note that expected.dataType.sameType(actual.dataType) = " + + expected.dataType.sameType(actual.dataType))) + throw ex + } + } + + // Test that the given SELECT expressions prune the test schema to the single-column schema + // defined by the given field + private def testSelect(inputSchema: StructType, selectExprs: String*) + (expected: StructField) { + test(s"SELECT ${selectExprs.map(s => s""""$s"""").mkString(", ")} should select the schema\n" + + indent(StructType(expected :: Nil).treeString)) { + for (selectExpr <- selectExprs) { + assertSelect(selectExpr, expected, inputSchema) + } + } + } + + private def assertSelect(expr: String, expected: StructField, inputSchema: StructType): Unit = { + val relation = LocalRelation(inputSchema.toAttributes) + unapplySelect(expr, relation) match { + case Some(field) => + assertResult(expected)(field)(expr) + case None => + val failureMessage = + "Failed to select a field from " + expr + ". 
" + + "Expected:\n" + + StructType(expected :: Nil).treeString + fail(failureMessage) + } + } + + private def unapplySelect(expr: String, relation: LocalRelation) = { + val parsedExpr = parseAsCatalystExpression(Seq(expr)).head + val select = relation.select(parsedExpr) + val analyzed = select.analyze + SelectedField.unapply(analyzed.expressions.head) + } + + private def parseAsCatalystExpression(exprs: Seq[String]) = { + exprs.map(CatalystSqlParser.parseExpression(_) match { + case namedExpr: NamedExpression => namedExpr + }) + } + + // Indent every line in `string` by four spaces + private def indent(string: String) = string.replaceAll("(?m)^", " ") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala index 750d9e4adf8b4..34dc6f37c0e4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.SparkEnv import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext @@ -33,4 +34,20 @@ class SparkPlanSuite extends QueryTest with SharedSQLContext { intercept[IllegalStateException] { plan.executeTake(1) } } + test("SPARK-23731 plans should be canonicalizable after being (de)serialized") { + withTempPath { path => + spark.range(1).write.parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + val fileSourceScanExec = + df.queryExecution.sparkPlan.collectFirst { case p: FileSourceScanExec => p }.get + val serializer = SparkEnv.get.serializer.newInstance() + val readback = + serializer.deserialize[FileSourceScanExec](serializer.serialize(fileSourceScanExec)) + try { + readback.canonicalized + } catch { + case e: Throwable => fail("FileSourceScanExec was not canonicalizable", e) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 107a2f7109793..28a060aff47b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -366,4 +366,15 @@ class SparkSqlParserSuite extends AnalysisTest { "SELECT a || b || c FROM t", Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) } + + test("SPARK-25046 Fix Alter View ... As Insert Into Table") { + // Single insert query + intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: ALTER VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("ALTER VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: ALTER VIEW ... AS FROM ... 
[INSERT INTO ...]+") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 7e317a4d80265..0a1c94cc4ccf4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -22,6 +22,7 @@ import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -31,10 +32,19 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { private var rand: Random = _ private var seed: Long = 0 + private val originalLimitFlatGlobalLimit = SQLConf.get.getConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT) + protected override def beforeAll(): Unit = { super.beforeAll() seed = System.currentTimeMillis() rand = new Random(seed) + + // Disable the optimization to make Sort-Limit match `TakeOrderedAndProject` semantics. + SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) + } + + protected override def afterAll() = { + SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) } private def generateRandomInputData(): DataFrame = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 3e31d22e15c0e..5c15ecd42fa0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.mockito.Mockito._ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} @@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null + private var taskContext: TaskContext = null + def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { @@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + taskContext = mock(classOf[TaskContext]) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity, PAGE_SIZE_BYTES ) @@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity PAGE_SIZE_BYTES ) @@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ 
-176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, StructType(Nil), StructType(Nil), - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) @@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index bf588d3bb7841..c882a9dd2148c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -231,7 +231,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` // which has duplicated keys and the number of entries exceeds its capacity. try { - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null) + TaskContext.setTaskContext(context) new UnsafeKVExternalSorter( schema, schema, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index a3ae93810aa3c..d305ce3e698ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -21,15 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} import java.util.Properties import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter /** @@ -43,7 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea } } -class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { +class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { val converter = unsafeRowConverter(schema) @@ -58,7 +56,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("toUnsafeRow() 
test helper method") { - // This currently doesnt work because the generic getter throws an exception. + // This currently doesn't work because the generic getter throws an exception. val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) assert(row.getString(0) === unsafeRow.getUTF8String(0).toString) @@ -97,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-10466: external sorter spilling with unsafe row serializer") { - var sc: SparkContext = null - var outputFile: File = null - val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten - Utils.tryWithSafeFinally { - val conf = new SparkConf() - .set("spark.shuffle.spill.initialMemoryThreshold", "1") - .set("spark.shuffle.sort.bypassMergeThreshold", "0") - .set("spark.testing.memory", "80000") - - sc = new SparkContext("local", "test", conf) - outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") - // prepare data - val converter = unsafeRowConverter(Array(IntegerType)) - val data = (1 to 10000).iterator.map { i => - (i, converter(Row(i))) - } - val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - - val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( - taskContext, - partitioner = Some(new HashPartitioner(10)), - serializer = new UnsafeRowSerializer(numFields = 1)) - - // Ensure we spilled something and have to merge them later - assert(sorter.numSpills === 0) - sorter.insertAll(data) - assert(sorter.numSpills > 0) + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.testing.memory", "80000") + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() + val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + outputFile.deleteOnExit() + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 10000).iterator.map { i => + (i, converter(Row(i))) + } + val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - // Merging spilled files should not throw assertion error - sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) - } { - // Clean up - if (sc != null) { - sc.stop() - } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + taskContext, + partitioner = Some(new HashPartitioner(10)), + serializer = new UnsafeRowSerializer(numFields = 1)) - // restore the spark env - SparkEnv.set(oldEnv) + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) - if (outputFile != null) { - outputFile.delete() - } - } + // Merging spilled files should not throw assertion error + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } test("SPARK-10403: unsafe row serializer with SortShuffleManager") { val conf = new SparkConf().set("spark.shuffle.manager", "sort") - sc = new SparkContext("local", "test", conf) + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) - val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, 
unsafeRow))) - .asInstanceOf[RDD[Product2[Int, InternalRow]]] + val rowsRDD = spark.sparkContext.parallelize( + Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)) + ).asInstanceOf[RDD[Product2[Int, InternalRow]]] val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rowsRDD, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9180a22c260f1..b714dcd5269fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -51,12 +51,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = spark.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) - assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1))) } test("BroadcastHashJoin should be included in WholeStageCodegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 261df06100aef..c36872a6a5289 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution.arrow -import java.io.File +import java.io.{ByteArrayOutputStream, DataOutputStream, File} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.ipc.JsonFileReader -import org.apache.arrow.vector.util.Validator +import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator} import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} @@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowPayload.collect() - assert(arrowPayloads.nonEmpty) - assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val arrowBatches = indexData.toArrowBatchRdd.collect() + assert(arrowBatches.nonEmpty) + assert(arrowBatches.length == indexData.rdd.getNumPartitions) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) val rowCount = arrowRecordBatches.map(_.getLength).sum assert(rowCount === indexData.count()) arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) @@ -1153,9 +1153,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val arrowPayloads = testData2.toArrowPayload.collect() - // 
NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) + val arrowBatches = testData2.toArrowBatchRdd.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches + assert(arrowBatches.length === 2) val schema = testData2.schema val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") @@ -1163,25 +1163,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { Files.write(json1, tempFile1, StandardCharsets.UTF_8) Files.write(json2, tempFile2, StandardCharsets.UTF_8) - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) + validateConversion(schema, arrowBatches(0), tempFile1) + validateConversion(schema, arrowBatches(1), tempFile2) } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() - assert(arrowPayload.isEmpty) + val arrowBatches = spark.emptyDataFrame.toArrowBatchRdd.collect() + assert(arrowBatches.isEmpty) val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() - assert(filteredArrowPayload.isEmpty) + val filteredArrowBatches = filteredDF.filter("i < 0").toArrowBatchRdd.collect() + assert(filteredArrowBatches.isEmpty) } test("empty partition collect") { val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() - assert(arrowPayloads.length === 1) + val arrowBatches = emptyPart.toArrowBatchRdd.collect() + assert(arrowBatches.length === 1) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) assert(arrowRecordBatches.head.getLength == 1) arrowRecordBatches.foreach(_.close()) allocator.close() @@ -1192,10 +1192,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val maxRecordsPerBatch = 3 spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowPayload.collect() - assert(arrowPayloads.length >= 4) + val arrowBatches = df.toArrowBatchRdd.collect() + assert(arrowBatches.length >= 4) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) var recordCount = 0 arrowRecordBatches.foreach { batch => assert(batch.getLength > 0) @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowBatchRdd.collect() } + runUnsupported { complexData.toArrowBatchRdd.collect() } } test("test Arrow Validator") { @@ -1318,7 +1318,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - test("roundtrip payloads") { + test("roundtrip arrow batches") { val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) @@ -1326,10 +1326,41 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = 
TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) - val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) - assert(schema == outputRowIter.schema) + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + assert(row.isNullAt(0)) + } + count += 1 + } + + assert(count == inputRows.length) + } + + test("ArrowBatchStreamWriter roundtrip") { + val inputRows = (0 until 9).map(InternalRow(_)) :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val ctx = TaskContext.empty() + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + + // Write batches to Arrow stream format as a byte array + val out = new ByteArrayOutputStream() + Utils.tryWithResource(new DataOutputStream(out)) { dataOut => + val writer = new ArrowBatchStreamWriter(schema, dataOut, null) + writer.writeBatches(batchIter) + writer.end() + } + + // Read Arrow stream into batches, then convert back to rows + val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray) + val readBatches = ArrowConverters.getBatchesFromStream(in) + val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, ctx) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => @@ -1348,15 +1379,15 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { private def collectAndValidate( df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) + validateConversion(df.schema, batchBytes, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, - arrowPayload: ArrowPayload, + batchBytes: Array[Byte], jsonFile: File, timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) @@ -1368,7 +1399,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) val vectorLoader = new VectorLoader(arrowRoot) - val arrowRecordBatch = arrowPayload.loadBatch(allocator) + val arrowRecordBatch = ArrowConverters.loadBatch(batchBytes, allocator) vectorLoader.load(arrowRecordBatch) val jsonRoot = jsonReader.read() Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala new file mode 100644 index 0000000000000..8711f5a8fa1ce --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -0,0 +1,836 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnVector +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure data source read performance. + * To run this: + * spark-submit --class + */ +object DataSourceReadBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceReadBenchmark") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. 
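+  // The defaults below keep ORC filter pushdown, the vectorized Parquet reader, and
+  // whole-stage codegen enabled, and avoid copying ORC batches into Spark's columnar
+  // format; individual benchmark cases override single keys through `withSQLConf`.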
+ spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + spark.conf.set(SQLConf.ORC_COPY_BATCH_TO_SPARK.key, "false") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val testDf = if (partition.isDefined) { + df.write.partitionBy(partition.get) + } else { + df.write + } + + saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") + saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") + saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet") + saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc") + } + + private def saveAsCsvTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").option("header", true).csv(dir) + spark.read.option("header", true).csv(dir).createOrReplaceTempView("csvTable") + } + + private def saveAsJsonTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").json(dir) + spark.read.json(dir).createOrReplaceTempView("jsonTable") + } + + private def saveAsParquetTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + } + + private def saveAsOrcTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").orc(dir) + spark.read.orc(dir).createOrReplaceTempView("orcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + // Benchmarks running through spark sql. + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + // Benchmarks driving reader component directly. 
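+    // Its cases bypass the SQL layer entirely: they list the Parquet files written by
+    // `prepareTable` and drive `VectorizedParquetRecordReader` directly, aggregating the
+    // single column in plain Scala, so only the raw decode path is measured.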
+ val parquetReaderBenchmark = new Benchmark( + s"Parquet Reader Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + sqlBenchmark.addCase("SQL Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + sqlBenchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 22964 / 23096 0.7 1460.0 1.0X + SQL Json 8469 / 8593 1.9 538.4 2.7X + SQL Parquet Vectorized 164 / 177 95.8 10.4 139.9X + SQL Parquet MR 1687 / 1706 9.3 107.2 13.6X + SQL ORC Vectorized 191 / 197 82.3 12.2 120.2X + SQL ORC Vectorized with copy 215 / 219 73.2 13.7 106.9X + SQL ORC MR 1392 / 1412 11.3 88.5 16.5X + + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 24090 / 24097 0.7 1531.6 1.0X + SQL Json 8791 / 8813 1.8 558.9 2.7X + SQL Parquet Vectorized 204 / 212 77.0 13.0 117.9X + SQL Parquet MR 1813 / 1850 8.7 115.3 13.3X + SQL ORC Vectorized 226 / 230 69.7 14.4 106.7X + SQL ORC Vectorized with copy 295 / 298 53.3 18.8 81.6X + SQL ORC MR 1526 / 1549 10.3 97.1 15.8X + + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 25637 / 25791 0.6 1629.9 1.0X + SQL Json 9532 / 9570 1.7 606.0 2.7X + SQL Parquet Vectorized 181 / 191 86.8 11.5 141.5X + SQL Parquet MR 2210 / 2227 7.1 140.5 11.6X + SQL ORC Vectorized 309 / 317 50.9 19.6 83.0X + SQL ORC Vectorized with copy 316 / 322 49.8 20.1 81.2X + SQL ORC MR 1650 / 1680 9.5 104.9 15.5X + + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 31617 / 31764 0.5 2010.1 1.0X + SQL Json 12440 / 12451 1.3 790.9 2.5X + SQL Parquet Vectorized 284 / 315 55.4 18.0 111.4X + SQL Parquet MR 2382 / 2390 6.6 151.5 13.3X + SQL ORC Vectorized 398 / 403 39.5 25.3 79.5X + SQL ORC Vectorized with copy 410 / 413 38.3 26.1 77.1X + SQL ORC MR 1783 / 1813 8.8 113.4 17.7X + + + SQL Single FLOAT 
Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 26679 / 26742 0.6 1696.2 1.0X + SQL Json 12490 / 12541 1.3 794.1 2.1X + SQL Parquet Vectorized 174 / 183 90.4 11.1 153.3X + SQL Parquet MR 2201 / 2223 7.1 140.0 12.1X + SQL ORC Vectorized 415 / 429 37.9 26.4 64.3X + SQL ORC Vectorized with copy 422 / 428 37.2 26.9 63.2X + SQL ORC MR 1767 / 1773 8.9 112.3 15.1X + + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 34223 / 34324 0.5 2175.8 1.0X + SQL Json 17784 / 17785 0.9 1130.7 1.9X + SQL Parquet Vectorized 277 / 283 56.7 17.6 123.4X + SQL Parquet MR 2356 / 2386 6.7 149.8 14.5X + SQL ORC Vectorized 533 / 536 29.5 33.9 64.2X + SQL ORC Vectorized with copy 541 / 546 29.1 34.4 63.3X + SQL ORC MR 2166 / 2177 7.3 137.7 15.8X + */ + sqlBenchmark.run() + + // Driving the parquet reader in batch mode directly. + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + parquetReaderBenchmark.addCase("ParquetReader Vectorized") { _ => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (ColumnVector, Int) => Unit = dataType match { + case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) + case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) + case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) + case LongType => (col: ColumnVector, i: Int) => longSum += col.getLong(i) + case FloatType => (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) + case DoubleType => (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + val col = batch.column(0) + while (reader.nextBatch()) { + val numRows = batch.numRows() + var i = 0 + while (i < numRows) { + if (!col.isNullAt(i)) aggregateValue(col, i) + i += 1 + } + } + } finally { + reader.close() + } + } + } + + // Decoding in vectorized but having the reader return rows. 
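+      // Same vectorized decoder as the previous case, but values are consumed through
+      // `batch.rowIterator()` instead of the `ColumnVector` API, so the difference between
+      // the two cases reflects the overhead of the row-based access path.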
+ parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (InternalRow) => Unit = dataType match { + case ByteType => (col: InternalRow) => longSum += col.getByte(0) + case ShortType => (col: InternalRow) => longSum += col.getShort(0) + case IntegerType => (col: InternalRow) => longSum += col.getInt(0) + case LongType => (col: InternalRow) => longSum += col.getLong(0) + case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) + case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val it = batch.rowIterator() + while (it.hasNext) { + val record = it.next() + if (!record.isNullAt(0)) aggregateValue(record) + } + } + } finally { + reader.close() + } + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 198 / 202 79.4 12.6 1.0X + ParquetReader Vectorized -> Row 119 / 121 132.3 7.6 1.7X + + + Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 282 / 287 55.8 17.9 1.0X + ParquetReader Vectorized -> Row 246 / 247 64.0 15.6 1.1X + + + Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 258 / 262 60.9 16.4 1.0X + ParquetReader Vectorized -> Row 259 / 260 60.8 16.5 1.0X + + + Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 361 / 369 43.6 23.0 1.0X + ParquetReader Vectorized -> Row 361 / 371 43.6 22.9 1.0X + + + Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 253 / 261 62.2 16.1 1.0X + ParquetReader Vectorized -> Row 254 / 256 61.9 16.2 1.0X + + + Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 357 / 364 44.0 22.7 1.0X + ParquetReader Vectorized -> Row 358 / 366 44.0 22.7 1.0X + */ + parquetReaderBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(c1), sum(length(c2)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => 
+ spark.sql("select sum(c1), sum(length(c2)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 27145 / 27158 0.4 2588.7 1.0X + SQL Json 12969 / 13337 0.8 1236.8 2.1X + SQL Parquet Vectorized 2419 / 2448 4.3 230.7 11.2X + SQL Parquet MR 4631 / 4633 2.3 441.7 5.9X + SQL ORC Vectorized 2412 / 2465 4.3 230.0 11.3X + SQL ORC Vectorized with copy 2633 / 2675 4.0 251.1 10.3X + SQL ORC MR 4280 / 4350 2.4 408.2 6.3X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("select cast((value % 200) + 10000 as STRING) as c1 from t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c1)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c1)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("select sum(length(c1)) from orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 17345 / 17424 0.6 1654.1 1.0X + SQL Json 8639 / 8664 1.2 823.9 2.0X + SQL Parquet Vectorized 839 / 854 12.5 80.0 20.7X + SQL Parquet MR 1771 / 1775 5.9 168.9 9.8X + SQL ORC Vectorized 550 / 569 19.1 52.4 
31.6X + SQL ORC Vectorized with copy 785 / 849 13.4 74.9 22.1X + SQL ORC MR 2168 / 2202 4.8 206.7 8.0X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Data column - CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + benchmark.addCase("Data column - Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + benchmark.addCase("Data column - Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + benchmark.addCase("Data column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + benchmark.addCase("Data column - ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Data column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Data column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - CSV") { _ => + spark.sql("select sum(p) from csvTable").collect() + } + + benchmark.addCase("Partition column - Json") { _ => + spark.sql("select sum(p) from jsonTable").collect() + } + + benchmark.addCase("Partition column - Parquet Vectorized") { _ => + spark.sql("select sum(p) from parquetTable").collect() + } + + benchmark.addCase("Partition column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p) from parquetTable").collect() + } + } + + benchmark.addCase("Partition column - ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + + benchmark.addCase("Partition column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - CSV") { _ => + spark.sql("select sum(p), sum(id) from csvTable").collect() + } + + benchmark.addCase("Both columns - Json") { _ => + spark.sql("select sum(p), sum(id) from jsonTable").collect() + } + + benchmark.addCase("Both columns - Parquet Vectorized") { _ => + spark.sql("select sum(p), sum(id) from parquetTable").collect() + } + + benchmark.addCase("Both columns - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p), sum(id) from parquetTable").collect + } + } + + benchmark.addCase("Both columns - ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Both column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM 
orcTable").collect() + } + } + + benchmark.addCase("Both columns - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Data column - CSV 32613 / 32841 0.5 2073.4 1.0X + Data column - Json 13343 / 13469 1.2 848.3 2.4X + Data column - Parquet Vectorized 302 / 318 52.1 19.2 108.0X + Data column - Parquet MR 2908 / 2924 5.4 184.9 11.2X + Data column - ORC Vectorized 412 / 425 38.1 26.2 79.1X + Data column - ORC Vectorized with copy 442 / 446 35.6 28.1 73.8X + Data column - ORC MR 2390 / 2396 6.6 152.0 13.6X + Partition column - CSV 9626 / 9683 1.6 612.0 3.4X + Partition column - Json 10909 / 10923 1.4 693.6 3.0X + Partition column - Parquet Vectorized 69 / 76 228.4 4.4 473.6X + Partition column - Parquet MR 1898 / 1933 8.3 120.7 17.2X + Partition column - ORC Vectorized 67 / 74 236.0 4.2 489.4X + Partition column - ORC Vectorized with copy 65 / 72 241.9 4.1 501.6X + Partition column - ORC MR 1743 / 1749 9.0 110.8 18.7X + Both columns - CSV 35523 / 35552 0.4 2258.5 0.9X + Both columns - Json 13676 / 13681 1.2 869.5 2.4X + Both columns - Parquet Vectorized 317 / 326 49.5 20.2 102.7X + Both columns - Parquet MR 3333 / 3336 4.7 211.9 9.8X + Both columns - ORC Vectorized 441 / 446 35.6 28.1 73.9X + Both column - ORC Vectorized with copy 517 / 524 30.4 32.9 63.1X + Both columns - ORC MR 2574 / 2577 6.1 163.6 12.7X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + val benchmark = new Benchmark("String with Nulls Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c2)) from csvTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c2)) from jsonTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + benchmark.addCase("ParquetReader Vectorized") { num => + var sum = 0 + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) + 
val batch = reader.resultBatch() + while (reader.nextBatch()) { + val rowIterator = batch.rowIterator() + while (rowIterator.hasNext) { + val row = rowIterator.next() + val value = row.getUTF8String(0) + if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() + } + } + } finally { + reader.close() + } + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 14875 / 14920 0.7 1418.6 1.0X + SQL Json 10974 / 10992 1.0 1046.5 1.4X + SQL Parquet Vectorized 1711 / 1750 6.1 163.2 8.7X + SQL Parquet MR 3838 / 3884 2.7 366.0 3.9X + ParquetReader Vectorized 1155 / 1168 9.1 110.2 12.9X + SQL ORC Vectorized 1341 / 1380 7.8 127.9 11.1X + SQL ORC Vectorized with copy 1659 / 1716 6.3 158.2 9.0X + SQL ORC MR 3594 / 3634 2.9 342.7 4.1X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 17219 / 17264 0.6 1642.1 1.0X + SQL Json 8843 / 8864 1.2 843.3 1.9X + SQL Parquet Vectorized 1169 / 1178 9.0 111.4 14.7X + SQL Parquet MR 2676 / 2697 3.9 255.2 6.4X + ParquetReader Vectorized 1068 / 1071 9.8 101.8 16.1X + SQL ORC Vectorized 1319 / 1319 7.9 125.8 13.1X + SQL ORC Vectorized with copy 1638 / 1639 6.4 156.2 10.5X + SQL ORC MR 3230 / 3257 3.2 308.1 5.3X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 13976 / 14053 0.8 1332.8 1.0X + SQL Json 5166 / 5176 2.0 492.6 2.7X + SQL Parquet Vectorized 274 / 282 38.2 26.2 50.9X + SQL Parquet MR 1553 / 1555 6.8 148.1 9.0X + ParquetReader Vectorized 241 / 246 43.5 23.0 57.9X + SQL ORC Vectorized 476 / 479 22.0 45.4 29.3X + SQL ORC Vectorized with copy 584 / 588 17.9 55.7 23.9X + SQL ORC MR 1720 / 1734 6.1 164.1 8.1X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql(s"SELECT sum(c$middle) FROM csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql(s"SELECT sum(c$middle) FROM jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => 
+ spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 3478 / 3481 0.3 3316.4 1.0X + SQL Json 2646 / 2654 0.4 2523.6 1.3X + SQL Parquet Vectorized 67 / 72 15.8 63.5 52.2X + SQL Parquet MR 207 / 214 5.1 197.6 16.8X + SQL ORC Vectorized 69 / 76 15.2 66.0 50.3X + SQL ORC Vectorized with copy 70 / 76 15.0 66.5 49.9X + SQL ORC MR 299 / 303 3.5 285.1 11.6X + + + Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 9214 / 9236 0.1 8786.7 1.0X + SQL Json 9943 / 9978 0.1 9482.7 0.9X + SQL Parquet Vectorized 77 / 86 13.6 73.3 119.8X + SQL Parquet MR 229 / 235 4.6 218.6 40.2X + SQL ORC Vectorized 84 / 96 12.5 80.0 109.9X + SQL ORC Vectorized with copy 83 / 91 12.6 79.4 110.7X + SQL ORC MR 843 / 854 1.2 804.0 10.9X + + + Single Column Scan from 100 columns Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 16503 / 16622 0.1 15738.9 1.0X + SQL Json 19109 / 19184 0.1 18224.2 0.9X + SQL Parquet Vectorized 99 / 108 10.6 94.3 166.8X + SQL Parquet MR 253 / 264 4.1 241.6 65.1X + SQL ORC Vectorized 107 / 114 9.8 101.6 154.8X + SQL ORC Vectorized with copy 107 / 118 9.8 102.1 154.1X + SQL ORC MR 1526 / 1529 0.7 1455.3 10.8X + */ + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + repeatedStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + for (columnWidth <- List(10, 50, 100)) { + columnsBenchmark(1024 * 1024 * 1, columnWidth) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala new file mode 100644 index 0000000000000..2d2cdebd067c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure data source write performance. + * By default it measures 4 data source format: Parquet, ORC, JSON, CSV: + * spark-submit --class + * To measure specified formats, run it with arguments: + * spark-submit --class format1 [format2] [...] + */ +object DataSourceWriteBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceWriteBenchmark") + .setIfMissing("spark.master", "local[1]") + .set("spark.sql.parquet.compression.codec", "snappy") + .set("spark.sql.orc.compression.codec", "snappy") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + val tempTable = "temp" + val numRows = 1024 * 1024 * 15 + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + def writeNumeric(table: String, format: String, benchmark: Benchmark, dataType: String): Unit = { + spark.sql(s"create table $table(id $dataType) using $format") + benchmark.addCase(s"Output Single $dataType Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS $dataType) AS c1 FROM $tempTable") + } + } + + def writeIntString(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 STRING) USING $format") + benchmark.addCase("Output Int and String Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS STRING) AS c2 FROM $tempTable") + } + } + + def writePartition(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(p INT, id INT) USING $format PARTITIONED BY (p)") + benchmark.addCase("Output Partitions") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS id," + + s" CAST(id % 2 AS INT) AS p FROM $tempTable") + } + } + + def writeBucket(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 INT) USING $format CLUSTERED BY (c2) INTO 2 BUCKETS") + benchmark.addCase("Output Buckets") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS INT) AS c2 FROM $tempTable") + } + } + + def main(args: Array[String]): Unit = { + val tableInt = "tableInt" + val tableDouble = "tableDouble" + val tableIntString = "tableIntString" + val 
tablePartition = "tablePartition" + val tableBucket = "tableBucket" + val formats: Seq[String] = if (args.isEmpty) { + Seq("Parquet", "ORC", "JSON", "CSV") + } else { + args + } + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1815 / 1932 8.7 115.4 1.0X + Output Single Double Column 1877 / 1878 8.4 119.3 1.0X + Output Int and String Column 6265 / 6543 2.5 398.3 0.3X + Output Partitions 4067 / 4457 3.9 258.6 0.4X + Output Buckets 5608 / 5820 2.8 356.6 0.3X + + ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1201 / 1239 13.1 76.3 1.0X + Output Single Double Column 1542 / 1600 10.2 98.0 0.8X + Output Int and String Column 6495 / 6580 2.4 412.9 0.2X + Output Partitions 3648 / 3842 4.3 231.9 0.3X + Output Buckets 5022 / 5145 3.1 319.3 0.2X + + JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1988 / 2093 7.9 126.4 1.0X + Output Single Double Column 2854 / 2911 5.5 181.4 0.7X + Output Int and String Column 6467 / 6653 2.4 411.1 0.3X + Output Partitions 4548 / 5055 3.5 289.1 0.4X + Output Buckets 5664 / 5765 2.8 360.1 0.4X + + CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 3025 / 3190 5.2 192.3 1.0X + Output Single Double Column 3575 / 3634 4.4 227.3 0.8X + Output Int and String Column 7313 / 7399 2.2 464.9 0.4X + Output Partitions 5105 / 5190 3.1 324.6 0.6X + Output Buckets 6986 / 6992 2.3 444.1 0.4X + */ + withTempTable(tempTable) { + spark.range(numRows).createOrReplaceTempView(tempTable) + formats.foreach { format => + withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { + val benchmark = new Benchmark(s"$format writer benchmark", numRows) + writeNumeric(tableInt, format, benchmark, "Int") + writeNumeric(tableDouble, format, benchmark, "Double") + writeIntString(tableIntString, format, benchmark) + writePartition(tablePartition, format, benchmark) + writeBucket(tableBucket, format, benchmark) + benchmark.run() + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..bdb60b44750c7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.io.{File, FileOutputStream, OutputStream} + +import scala.util.{Random, Try} + +import org.scalatest.{BeforeAndAfterEachTestData, Suite, TestData} + +import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType +import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure read performance with Filter pushdown. + * To run this: + * build/sbt "sql/test-only *FilterPushdownBenchmark" + * + * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". + */ +class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfterEachTest { + private val conf = new SparkConf() + .setAppName(this.getClass.getSimpleName) + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") + .setIfMissing("orc.compression", "snappy") + .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + + private val numRows = 1024 * 1024 * 15 + private val width = 5 + private val mid = numRows / 2 + private val blockSize = 1048576 + + private val spark = SparkSession.builder().config(conf).getOrCreate() + + private var out: OutputStream = _ + + override def beforeAll() { + super.beforeAll() + out = new FileOutputStream(new File("benchmarks/FilterPushdownBenchmark-results.txt")) + } + + override def beforeEach(td: TestData) { + super.beforeEach(td) + val separator = "=" * 96 + val testHeader = (separator + '\n' + td.name + '\n' + separator + '\n' + '\n').getBytes + out.write(testHeader) + } + + override def afterEach(td: TestData) { + out.write('\n') + super.afterEach(td) + } + + override def afterAll() { + try { + out.close() + } finally { + super.afterAll() + } + } + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable( + dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val valueCol = if (useStringForValue) { + 
monotonically_increasing_id().cast("string") + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("value", valueCol) + .sort("value") + + saveAsTable(df, dir) + } + + private def prepareStringDictTable( + dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = { + val selectExpr = (0 to width).map { + case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value" + case i => s"CAST(rand() AS STRING) c$i" + } + val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") + + saveAsTable(df, dir) + } + + private def saveAsTable(df: DataFrame, dir: File): Unit = { + val orcPath = dir.getCanonicalPath + "/orc" + val parquetPath = dir.getCanonicalPath + "/parquet" + + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite") + .option("orc.dictionary.key.threshold", 1.0) + .option("orc.stripe.size", blockSize).orc(orcPath) + spark.read.orc(orcPath).createOrReplaceTempView("orcTable") + + df.write.mode("overwrite") + .option("parquet.block.size", blockSize).parquet(parquetPath) + spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5, output = Some(out)) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + benchmark.run() + } + + private def runIntBenchmark(numRows: Int, width: Int, mid: Int): Unit = { + Seq("value IS NULL", s"$mid < value AND value < $mid").foreach { whereExpr => + val title = s"Select 0 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = $mid", + s"value <=> $mid", + s"$mid <= value AND value <= $mid", + s"${mid - 1} < value AND value < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% int rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all int rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + private def runStringBenchmark( + numRows: Int, width: Int, searchValue: Int, colType: String): Unit = { + Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'") + .foreach { whereExpr => + val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value") + 
filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = '$searchValue'", + s"value <=> '$searchValue'", + s"'$searchValue' <= value AND value <= '$searchValue'" + ).foreach { whereExpr => + val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq("value IS NOT NULL").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all $colType rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + ignore("Pushdown for many distinct value case") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + Seq(true, false).foreach { useStringForValue => + prepareTable(dir, numRows, width, useStringForValue) + if (useStringForValue) { + runStringBenchmark(numRows, width, mid, "string") + } else { + runIntBenchmark(numRows, width, mid) + } + } + } + } + } + + ignore("Pushdown for few distinct value case (use dictionary encoding)") { + withTempPath { dir => + val numDistinctValues = 200 + + withTempTable("orcTable", "patquetTable") { + prepareStringDictTable(dir, numRows, numDistinctValues, width) + runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") + } + } + } + + ignore("Pushdown benchmark for StringStartsWith") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, true) + Seq( + "value like '10%'", + "value like '1000%'", + s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'" + ).foreach { whereExpr => + val title = s"StringStartsWith filter: ($whereExpr)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + + ignore(s"Pushdown benchmark for ${DecimalType.simpleString}") { + withTempPath { dir => + Seq( + s"decimal(${Decimal.MAX_INT_DIGITS}, 2)", + s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)", + s"decimal(${DecimalType.MAX_PRECISION}, 2)" + ).foreach { dt => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) { + monotonically_increasing_id() % 9999999 + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).selectExpr(columns: _*).withColumn("value", valueCol.cast(dt)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = $mid").foreach { whereExpr => + val title = s"Select 1 $dt row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% $dt rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + } + } + } + } + + ignore("Pushdown benchmark for InSet -> InFilters") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, false) + Seq(5, 10, 50, 100).foreach { count => + Seq(10, 50, 90).foreach { distribution => + val filter = + Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100)) + val whereExpr = s"value in(${filter.mkString(",")})" + val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + } + + ignore(s"Pushdown benchmark for 
${ByteType.simpleString}") { + withTempPath { dir => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) + .orderBy("value") + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST(${Byte.MaxValue / 2} AS ${ByteType.simpleString})") + .foreach { whereExpr => + val title = s"Select 1 ${ByteType.simpleString} row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% ${ByteType.simpleString} rows " + + s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))", + s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})", + selectExpr + ) + } + } + } + } + + ignore(s"Pushdown benchmark for Timestamp") { + withTempPath { dir => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> true.toString) { + ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType => + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) { + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(TimestampType)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => + val title = s"Select 1 timestamp stored as $fileType row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% timestamp stored as $fileType rows " + + s"(value < CAST(${numRows * percent / 100} AS timestamp))", + s"value < CAST(${numRows * percent / 100} as timestamp)", + selectExpr + ) + } + } + } + } + } + } + } +} + +trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => + + override def beforeEach(td: TestData) { + super.beforeEach(td) + } + + override def afterEach(td: TestData) { + super.afterEach(td) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 69247d7f4e9aa..fccee97820e75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -58,10 +58,13 @@ object TPCDSQueryBenchmark extends Logging { }.toMap } - def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { - val tableSizes = setupTables(dataLocation) + def runTpcdsQueries( + queryLocation: String, + queries: Seq[String], + tableSizes: Map[String, Long], + nameSuffix: String = ""): Unit = { queries.foreach { name => - val queryString = resourceToString(s"tpcds/$name.sql", + val queryString = resourceToString(s"$queryLocation/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the @@ -69,7 
+72,7 @@ object TPCDSQueryBenchmark extends Logging { val queryRelations = scala.collection.mutable.HashSet[String]() spark.sql(queryString).queryExecution.analyzed.foreach { case SubqueryAlias(alias, _: LogicalRelation) => - queryRelations.add(alias) + queryRelations.add(alias.identifier) case LogicalRelation(_, _, Some(catalogTable), _) => queryRelations.add(catalogTable.identifier.table) case HiveTableRelation(tableMeta, _, _) => @@ -78,7 +81,7 @@ object TPCDSQueryBenchmark extends Logging { } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5) - benchmark.addCase(name) { i => + benchmark.addCase(s"$name$nameSuffix") { _ => spark.sql(queryString).collect() } logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n") @@ -87,10 +90,20 @@ object TPCDSQueryBenchmark extends Logging { } } + def filterQueries( + origQueries: Seq[String], + args: TPCDSQueryBenchmarkArguments): Seq[String] = { + if (args.queryFilter.nonEmpty) { + origQueries.filter(args.queryFilter.contains) + } else { + origQueries + } + } + def main(args: Array[String]): Unit = { val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args) - // List of all TPC-DS queries + // List of all TPC-DS v1.4 queries val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -103,20 +116,25 @@ object TPCDSQueryBenchmark extends Logging { "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + // This list only includes TPC-DS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + // If `--query-filter` defined, filters the queries that this option selects - val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) { - val queries = tpcdsQueries.filter { case queryName => - benchmarkArgs.queryFilter.contains(queryName) - } - if (queries.isEmpty) { - throw new RuntimeException( - s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") - } - queries - } else { - tpcdsQueries + val queriesV1_4ToRun = filterQueries(tpcdsQueries, benchmarkArgs) + val queriesV2_7ToRun = filterQueries(tpcdsQueriesV2_7, benchmarkArgs) + + if ((queriesV1_4ToRun ++ queriesV2_7ToRun).isEmpty) { + throw new RuntimeException( + s"Empty queries to run. 
Bad query name filter: ${benchmarkArgs.queryFilter}") } - tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun) + val tableSizes = setupTables(benchmarkArgs.dataLocation) + runTpcdsQueries(queryLocation = "tpcds", queries = queriesV1_4ToRun, tableSizes) + runTpcdsQueries(queryLocation = "tpcds-v2.7.0", queries = queriesV2_7ToRun, tableSizes, + nameSuffix = "-v2.7") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 9b7b316211d30..efc2f20a907f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, data.logicalPlan) - assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) - inMemoryRelation.cachedColumnBuffers.collect().head match { + assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel) + inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match { case _: CachedBatch => case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") } @@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { @@ -503,7 +503,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized - assert(inMemoryRelation.computeStats().sizeInBytes === 740) + assert(inMemoryRelation.computeStats().sizeInBytes === 800) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -516,7 +516,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated - assert(inMemoryRelation2.computeStats().sizeInBytes === 740) + assert(inMemoryRelation2.computeStats().sizeInBytes === 800) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 9d862cfdecb21..af493e93b5192 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import 
org.apache.spark.sql.test.SQLTestData._ @@ -35,6 +36,12 @@ class PartitionBatchPruningSuite private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE) private lazy val originalInMemoryPartitionPruning = spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING) + private val testArrayData = (1 to 100).map { key => + Tuple1(Array.fill(key)(key)) + } + private val testBinaryData = (1 to 100).map { key => + Tuple1(Array.fill(key)(key.toByte)) + } override protected def beforeAll(): Unit = { super.beforeAll() @@ -71,12 +78,22 @@ class PartitionBatchPruningSuite }, 5).toDF() pruningStringData.createOrReplaceTempView("pruningStringData") spark.catalog.cacheTable("pruningStringData") + + val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF() + pruningArrayData.createOrReplaceTempView("pruningArrayData") + spark.catalog.cacheTable("pruningArrayData") + + val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF() + pruningBinaryData.createOrReplaceTempView("pruningBinaryData") + spark.catalog.cacheTable("pruningBinaryData") } override protected def afterEach(): Unit = { try { spark.catalog.uncacheTable("pruningData") spark.catalog.uncacheTable("pruningStringData") + spark.catalog.uncacheTable("pruningArrayData") + spark.catalog.uncacheTable("pruningBinaryData") } finally { super.afterEach() } @@ -95,6 +112,14 @@ class PartitionBatchPruningSuite checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11) checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100) checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100) + // Do not filter on array type + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1))) + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1))) + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)( + testArrayData.map(_._1)) + // Do not filter on binary type + checkBatchPruning( + "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte))) // IS NULL checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) { @@ -131,6 +156,9 @@ class PartitionBatchPruningSuite checkBatchPruning( "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)( Seq(150)) + // Do not filter on array type + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)( + Seq(Array(1), Array(2, 2))) // With unsupported `InSet` predicate { @@ -161,7 +189,7 @@ class PartitionBatchPruningSuite query: String, expectedReadPartitions: Int, expectedReadBatches: Int)( - expectedQueryResult: => Seq[Int]): Unit = { + expectedQueryResult: => Seq[Any]): Unit = { test(query) { val df = sql(query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 3998ceca38b30..f8d98dead2d42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -52,23 +52,24 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo protected override def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): CatalogTable = { + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable = 
{ val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() .putString("key", "value") .build() + val schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") CatalogTable( identifier = name, tableType = CatalogTableType.EXTERNAL, storage = storage, - schema = new StructType() - .add("col1", "int", nullable = true, metadata = metadata) - .add("col2", "string") - .add("a", "int") - .add("b", "int"), + schema = schema.copy( + fields = schema.fields ++ partitionCols.map(StructField(_, IntegerType))), provider = Some("parquet"), - partitionColumnNames = Seq("a", "b"), + partitionColumnNames = partitionCols, createTime = 0L, createVersion = org.apache.spark.SPARK_VERSION, tracksPartitionsInCatalog = true) @@ -176,7 +177,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { protected def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): CatalogTable + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -228,8 +230,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { private def createTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): Unit = { - catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): Unit = { + catalog.createTable( + generateTable(catalog, name, isDataSource, partitionCols), ignoreIfExists = false) } private def createTablePartition( @@ -441,6 +445,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename a managed table with existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab2"))) + try { + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING $dataSource AS SELECT 1, 'a'") + tableLoc.mkdir() + val ex = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO tab2") + }.getMessage + val expectedMsg = "Can not rename the managed table('`tab1`'). 
The associated location" + assert(ex.contains(expectedMsg)) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], @@ -1113,7 +1135,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("alter table: recover partition (parallel)") { - withSQLConf("spark.rdd.parallelListingThreshold" -> "1") { + withSQLConf("spark.rdd.parallelListingThreshold" -> "0") { testRecoverPartitions() } } @@ -1126,23 +1148,32 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } val tableIdent = TableIdentifier("tab1") - createTable(catalog, tableIdent) - val part1 = Map("a" -> "1", "b" -> "5") + createTable(catalog, tableIdent, partitionCols = Seq("a", "b", "c")) + val part1 = Map("a" -> "1", "b" -> "5", "c" -> "19") createTablePartition(catalog, part1, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) - val part2 = Map("a" -> "2", "b" -> "6") + val part2 = Map("a" -> "2", "b" -> "6", "c" -> "31") val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid - fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) - fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file - fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "_SUCCESS")) // file - fs.mkdirs(new Path(new Path(root, "A=2"), "B=6")) - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "b.csv")) // file - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "c.csv")) // file - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), ".hiddenFile")) // file - fs.mkdirs(new Path(new Path(root, "A=2/B=6"), "_temporary")) + fs.mkdirs(new Path(new Path(new Path(root, "a=1"), "b=5"), "c=19")) + fs.createNewFile(new Path(new Path(root, "a=1/b=5/c=19"), "a.csv")) // file + fs.createNewFile(new Path(new Path(root, "a=1/b=5/c=19"), "_SUCCESS")) // file + + fs.mkdirs(new Path(new Path(new Path(root, "A=2"), "B=6"), "C=31")) + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), "b.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), "c.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), ".hiddenFile")) // file + fs.mkdirs(new Path(new Path(root, "A=2/B=6/C=31"), "_temporary")) + + val parts = (10 to 100).map { a => + val part = Map("a" -> a.toString, "b" -> "5", "c" -> "42") + fs.mkdirs(new Path(new Path(new Path(root, s"a=$a"), "b=5"), "c=42")) + fs.createNewFile(new Path(new Path(root, s"a=$a/b=5/c=42"), "a.csv")) // file + createTablePartition(catalog, part, tableIdent) + part + } // invalid fs.mkdirs(new Path(new Path(root, "a"), "b")) // bad name @@ -1156,7 +1187,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { try { sql("ALTER TABLE tab1 RECOVER PARTITIONS") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2)) + Set(part1, part2) ++ parts) if (!isUsingHiveMetastore) { assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") @@ -2231,6 +2262,68 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("Partition table should load empty static partitions") { + // All static partitions + withTable("t", "t1", "t2") { + withTempPath { dir => + spark.sql("CREATE TABLE t(a int) USING parquet") + spark.sql("CREATE TABLE t1(a int, c string, b string) " + + s"USING 
parquet PARTITIONED BY(c, b) LOCATION '${dir.toURI}'") + + // datasource table + validateStaticPartitionTable("t1") + + // hive table + if (isUsingHiveMetastore) { + spark.sql("CREATE TABLE t2(a int) " + + s"PARTITIONED BY(c string, b string) LOCATION '${dir.toURI}'") + validateStaticPartitionTable("t2") + } + + def validateStaticPartitionTable(tableName: String): Unit = { + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + spark.sql( + s"INSERT INTO TABLE $tableName PARTITION(b='b', c='c') SELECT * FROM t WHERE 1 = 0") + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 1) + assert(new File(dir, "c=c/b=b").exists()) + checkAnswer(spark.table(tableName), Nil) + } + } + } + + // Partial dynamic partitions + withTable("t", "t1", "t2") { + withTempPath { dir => + spark.sql("CREATE TABLE t(a int) USING parquet") + spark.sql("CREATE TABLE t1(a int, b string, c string) " + + s"USING parquet PARTITIONED BY(c, b) LOCATION '${dir.toURI}'") + + // datasource table + validatePartialStaticPartitionTable("t1") + + // hive table + if (isUsingHiveMetastore) { + spark.sql("CREATE TABLE t2(a int) " + + s"PARTITIONED BY(c string, b string) LOCATION '${dir.toURI}'") + validatePartialStaticPartitionTable("t2") + } + + def validatePartialStaticPartitionTable(tableName: String): Unit = { + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + spark.sql( + s"INSERT INTO TABLE $tableName PARTITION(c='c', b) SELECT *, 'b' FROM t WHERE 1 = 0") + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + assert(!new File(dir, "c=c/b=b").exists()) + checkAnswer(spark.table(tableName), Nil) + } + } + } + } + Seq(true, false).foreach { shouldDelete => val tcName = if (shouldDelete) "non-existing" else "existed" test(s"CTAS for external data source table with a $tcName location") { @@ -2495,7 +2588,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { test("alter datasource table add columns - text format not supported") { withTable("t1") { - sql("CREATE TABLE t1 (c1 int) USING text") + sql("CREATE TABLE t1 (c1 string) USING text") val e = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8764f0c42cf9f..bceaf1a9ec061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala new file mode 100644 index 0000000000000..508614a7e476c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext + +class HadoopFileLinesReaderSuite extends SharedSQLContext { + def getLines( + path: File, + text: String, + ranges: Seq[(Long, Long)], + delimiter: Option[String] = None, + conf: Option[Configuration] = None): Seq[String] = { + val delimOpt = delimiter.map(_.getBytes(StandardCharsets.UTF_8)) + Files.write(path.toPath, text.getBytes(StandardCharsets.UTF_8)) + + val lines = ranges.map { case (start, length) => + val file = PartitionedFile(InternalRow.empty, path.getCanonicalPath, start, length) + val hadoopConf = conf.getOrElse(spark.sessionState.newHadoopConf()) + val reader = new HadoopFileLinesReader(file, delimOpt, hadoopConf) + + reader.map(_.toString) + }.flatten + + lines + } + + test("A split ends at the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 1), (1, 3))) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 2), (2, 2))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the end of the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 3), (3, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split covers two lines") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 4), (4, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 1), (1, 4)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split slices the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 2), (2, 3)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("The first split covers the first line and the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 4), (4, 1)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the first line") { + withTempPath { path => + 
val lines = getLines(path, text = "abc,def", ranges = Seq((0, 1)), Some(",")) + assert(lines == Seq("abc")) + } + } + + test("The split cuts both lines") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((2, 2)), Some(",")) + assert(lines == Seq("def")) + } + } + + test("io.file.buffer.size is less than line length") { + withSQLConf("io.file.buffer.size" -> "2") { + withTempPath { path => + val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) + assert(lines == Seq("123456")) + } + } + } + + test("line cannot be longer than line.maxlength") { + withSQLConf("mapreduce.input.linerecordreader.line.maxlength" -> "5") { + withTempPath { path => + val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) + assert(lines == Seq("1234")) + } + } + } + + test("default delimiter is 0xd or 0xa or 0xd0xa") { + withTempPath { path => + val lines = getLines(path, text = "1\r2\n3\r\n4", ranges = Seq((0, 3), (3, 5))) + assert(lines == Seq("1", "2", "3", "4")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala new file mode 100644 index 0000000000000..23c58e175fe5e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.internal.SQLConf + +/** + * Read schema suites have the following hierarchy and aims to guarantee users + * a backward-compatible read-schema change coverage on file-based data sources, and + * to prevent future regressions. + * + * ReadSchemaSuite + * -> CSVReadSchemaSuite + * -> HeaderCSVReadSchemaSuite + * + * -> JsonReadSchemaSuite + * + * -> OrcReadSchemaSuite + * -> VectorizedOrcReadSchemaSuite + * + * -> ParquetReadSchemaSuite + * -> VectorizedParquetReadSchemaSuite + * -> MergedParquetReadSchemaSuite + */ + +/** + * All file-based data sources supports column addition and removal at the end. 
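+ *
+ * A concrete suite only has to extend this base class and mix in the traits that
+ * match the coverage of its data source. As a sketch, a suite for a hypothetical
+ * source (the class and format name below are purely illustrative) could look like:
+ * {{{
+ *   class MyFormatReadSchemaSuite
+ *     extends ReadSchemaSuite
+ *     with HideColumnInTheMiddleTest
+ *     with ChangePositionTest {
+ *
+ *     override val format: String = "myFormat"
+ *   }
+ * }}}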
+ */ +abstract class ReadSchemaSuite + extends AddColumnTest + with HideColumnAtTheEndTest { + + var originalConf: Boolean = _ +} + +class CSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" +} + +class HeaderCSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" + + override val options = Map("header" -> "true") +} + +class JsonReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "json" +} + +class OrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedOrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with BooleanTypeTest + with IntegralTypeTest + with ToDoubleTypeTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class ParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class MergedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED) + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, originalConf) + super.afterAll() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala new file mode 100644 index 0000000000000..2a5457e00b4ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +/** + * The reader schema is said to be evolved (or projected) when it changed after the data is + * written by writers. The followings are supported in file-based data sources. + * Note that partition columns are not maintained in files. Here, `column` means non-partition + * column. + * + * 1. Add a column + * 2. Hide a column + * 3. Change a column position + * 4. Change a column type (Upcast) + * + * Here, we consider safe changes without data loss. For example, data type changes should be + * from small types to larger types like `int`-to-`long`, not vice versa. + * + * So far, file-based data sources have the following coverages. + * + * | File Format | Coverage | Note | + * | ------------ | ------------ | ------------------------------------------------------ | + * | TEXT | N/A | Schema consists of a single string column. | + * | CSV | 1, 2, 4 | | + * | JSON | 1, 2, 3, 4 | | + * | ORC | 1, 2, 3, 4 | Native vectorized ORC reader has the widest coverage. | + * | PARQUET | 1, 2, 3 | | + * + * This aims to provide an explicit test coverage for reader schema change on file-based data + * sources. Since a file format has its own coverage, we need a test suite for each file-based + * data source with corresponding supported test case traits. + * + * The following is a hierarchy of test traits. + * + * ReadSchemaTest + * -> AddColumnTest + * -> HideColumnTest + * -> ChangePositionTest + * -> BooleanTypeTest + * -> IntegralTypeTest + * -> ToDoubleTypeTest + * -> ToDecimalTypeTest + */ + +trait ReadSchemaTest extends QueryTest with SQLTestUtils with SharedSQLContext { + val format: String + val options: Map[String, String] = Map.empty[String, String] +} + +/** + * Add column (Case 1). + * This test suite assumes that the missing column should be `null`. 
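+ *
+ * For example (paths and column names below are only illustrative), data written
+ * before a column was added can still be read through the newer, wider schema,
+ * and the missing column comes back as `null`:
+ * {{{
+ *   Seq("a", "b").toDF("col1").write.parquet(path)
+ *   spark.read.schema("col1 string, col2 string").parquet(path).collect()
+ *   // => Row("a", null), Row("b", null)
+ * }}}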
+ */ +trait AddColumnTest extends ReadSchemaTest { + import testImplicits._ + + test("append column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq("a", "b").toDF("col1") + val df2 = df1.withColumn("col2", lit("x")) + val df3 = df2.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + val dir3 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + df3.write.format(format).options(options).save(dir3) + + val df = spark.read + .schema(df3.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", null, null, "one"), + Row("b", null, null, "one"), + Row("a", "x", null, "two"), + Row("b", "x", null, "two"), + Row("a", "x", "y", "three"), + Row("b", "x", "y", "three"))) + } + } +} + +/** + * Hide column (Case 2-1). + */ +trait HideColumnAtTheEndTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(df1.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("1", "a", "two"), + Row("2", "b", "two"), + Row("1", "a", "three"), + Row("2", "b", "three"))) + + val df3 = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df3, Seq( + Row("1", "two"), + Row("2", "two"), + Row("1", "three"), + Row("2", "three"))) + } + } +} + +/** + * Hide column in the middle (Case 2-2). + */ +trait HideColumnInTheMiddleTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column in the middle") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema("col2 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", "two"), + Row("b", "two"), + Row("a", "three"), + Row("b", "three"))) + } + } +} + +/** + * Change column positions (Case 3). + * This suite assumes that all data set have the same number of columns. 
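+ *
+ * For example (paths and column names are only illustrative), two directories
+ * written with the same columns in a different physical order should still line
+ * up by column name when read back with a single schema:
+ * {{{
+ *   Seq(("1", "a")).toDF("col1", "col2").write.parquet(s"$dir/part=one")
+ *   Seq(("b", "2")).toDF("col2", "col1").write.parquet(s"$dir/part=two")
+ *   spark.read.schema("col1 string, col2 string").parquet(dir)
+ *     .select("col1", "col2")   // rows ("1", "a") and ("2", "b")
+ * }}}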
+ */ +trait ChangePositionTest extends ReadSchemaTest { + import testImplicits._ + + test("change column position") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b"), ("3", "c")).toDF("col1", "col2") + val df2 = Seq(("d", "4"), ("e", "5"), ("f", "6")).toDF("col2", "col1") + val unionDF = df1.unionByName(df2) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1", "col2") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait BooleanTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("change column type from boolean to byte/short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val values = (1 to 10).map(_ % 2) + val booleanDF = (1 to 10).map(_ % 2 == 1).toDF("col1") + val byteDF = values.map(_.toByte).toDF("col1") + val shortDF = values.map(_.toShort).toDF("col1") + val intDF = values.toDF("col1") + val longDF = values.map(_.toLong).toDF("col1") + + booleanDF.write.mode("overwrite").format(format).options(options).save(path) + + Seq( + ("col1 byte", byteDF), + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToStringTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("read as string") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + .selectExpr("cast(col1 AS STRING) col1") + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. 
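+ *
+ * For example, for a source that supports this case (JSON here, purely as an
+ * illustration), integral data written with a narrow type can be read back
+ * through any wider integral type:
+ * {{{
+ *   (1 to 3).map(_.toByte).toDF("col1").write.json(path)
+ *   spark.read.schema("col1 long").json(path).collect()
+ *   // => Row(1L), Row(2L), Row(3L)
+ * }}}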
+ */ +trait IntegralTypeTest extends ReadSchemaTest { + + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val byteDF = values.map(_.toByte).toDF("col1") + private lazy val shortDF = values.map(_.toShort).toDF("col1") + private lazy val intDF = values.toDF("col1") + private lazy val longDF = values.map(_.toLong).toDF("col1") + + test("change column type from byte to short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + byteDF.write.format(format).options(options).save(path) + + Seq( + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from short to int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + shortDF.write.format(format).options(options).save(path) + + Seq(("col1 int", intDF), ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from int to long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + intDF.write.format(format).options(options).save(path) + + Seq(("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("read byte, int, short, long together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. 
+ */ +trait ToDoubleTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF) + + test("change column type from float to double") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read.schema("col1 double").format(format).options(options).load(path) + + checkAnswer(df, doubleDF) + } + } + + test("read float and double together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDecimalTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val decimalDF = values.map(BigDecimal(_)).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF).union(decimalDF) + + test("change column type from float to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("change column type from double to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + doubleDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("read float, double, decimal together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + val decimalDir = s"$path${File.separator}part=decimal" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + decimalDF.write.format(format).options(options).save(decimalDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala index 4b3ca8e60cab6..a1da3ec43eae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -23,9 +23,6 @@ import org.apache.spark.sql.test.SharedSQLContext class SaveIntoDataSourceCommandSuite extends SharedSQLContext { - override protected def sparkConf: SparkConf = 
super.sparkConf - .set("spark.redaction.regex", "(?i)password|url") - test("simpleString is redacted") { val URL = "connection.url" val PASS = "123" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala new file mode 100644 index 0000000000000..24f5f55d55485 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure CSV read/write performance. + * To run this: + * spark-submit --class --jars + */ +object CSVBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-csv-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) + + withTempPath { path => + val str = (0 until 10000).map(i => s""""$i"""").mkString(",") + + spark.range(rowsNum) + .map(_ => str) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val schema = new StructType().add("value", StringType) + val ds = spark.read.option("header", true).schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"One quoted string", numIters) { _ => + ds.filter((_: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + One quoted string 30273 / 30549 0.0 605451.2 1.0X + */ + benchmark.run() + } + } + + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) 
+ + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 81091 / 81692 0.0 81090.7 1.0X + Select 100 columns 30003 / 34448 0.0 30003.0 2.7X + Select one column 24792 / 24855 0.0 24792.0 3.3X + count() 24344 / 24642 0.0 24343.8 3.3X + */ + benchmark.run() + } + } + + def countBenchmark(rowsNum: Int): Unit = { + val colsNum = 10 + val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ => + ds.select("*").filter((_: Row) => true).count() + } + benchmark.addCase(s"Select 1 column + count()", 3) { _ => + ds.select($"col1").filter((_: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz + + Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + Select 10 columns + count() 12598 / 12740 0.8 1259.8 1.0X + Select 1 column + count() 7960 / 8175 1.3 796.0 1.6X + count() 2332 / 2386 4.3 233.2 5.4X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) + countBenchmark(10 * 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 661742087112f..57e36e082653c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { test("String fields types are inferred correctly from null types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(NullType, "", options) == NullType) assert(CSVInferSchema.inferField(NullType, null, options) == NullType) assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) @@ -41,7 +41,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("String fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], 
"GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) @@ -60,21 +60,21 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Timestamp field types are inferred correctly via custom data format") { - var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), "GMT") + var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - options = new CSVOptions(Map("timestampFormat" -> "yyyy"), "GMT") + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) } test("Timestamp field types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) } test("Boolean fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) } @@ -92,12 +92,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Null fields are handled properly when a nullValue is specified") { - var options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) - options = new CSVOptions(Map("nullValue" -> "\\N"), "GMT") + options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) @@ -111,12 +111,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), "GMT") + val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) } test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). 
assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == @@ -132,9 +132,9 @@ class CSVInferSchemaSuite extends SparkFunSuite { == StringType) } - test("DoubleType should be infered when user defined nan/inf are provided") { + test("DoubleType should be inferred when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", - "positiveInf" -> "inf"), "GMT") + "positiveInf" -> "inf"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4398e547d9217..5a1d6679ebbdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -18,24 +18,29 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.File -import java.nio.charset.UnsupportedCharsetException +import java.nio.charset.{Charset, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Properties + import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.log4j.{AppenderSkeleton, LogManager} +import org.apache.log4j.spi.LoggingEvent import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with TestCsvData { import testImplicits._ private val carsFile = "test-data/cars.csv" @@ -57,10 +62,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private val valueMalformedFile = "test-data/value-malformed.csv" - private def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - /** Verifies data and schema. 
*/ private def verifyCars( df: DataFrame, @@ -261,14 +262,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -513,6 +516,41 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + test("SPARK-19018: Save csv with custom charset") { + + // scalastyle:off nonascii + val content = "µß áâä ÁÂÄ" + // scalastyle:on nonascii + + Seq("iso-8859-1", "utf-8", "utf-16", "utf-32", "windows-1250").foreach { encoding => + withTempPath { path => + val csvDir = new File(path, "csv") + Seq(content).toDF().write + .option("encoding", encoding) + .csv(csvDir.getCanonicalPath) + + csvDir.listFiles().filter(_.getName.endsWith("csv")).foreach({ csvFile => + val readback = Files.readAllBytes(csvFile.toPath) + val expected = (content + Properties.lineSeparator).getBytes(Charset.forName(encoding)) + assert(readback === expected) + }) + } + } + } + + test("SPARK-19018: error handling for unsupported charsets") { + val exception = intercept[SparkException] { + withTempPath { path => + val csvDir = new File(path, "csv").getCanonicalPath + Seq("a,A,c,A,b,B").toDF().write + .option("encoding", "1-9588-osi") + .csv(csvDir) + } + } + + assert(exception.getCause.getMessage.contains("1-9588-osi")) + } + test("commented lines in CSV data") { Seq("false", "true").foreach { multiLine => @@ -735,39 +773,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(numbers.count() == 8) } - test("error handling for unsupported data types.") { - withTempDir { dir => - val csvDir = new File(dir, "csv").getCanonicalPath - var msg = intercept[UnsupportedOperationException] { - Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support struct data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support map data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") - .write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) - spark.range(1).write.csv(csvDir) - spark.read.schema(schema).csv(csvDir).collect() - }.getMessage - assert(msg.contains("CSV data source does not support array data type.")) - } - } - test("SPARK-15585 turn off quotations") { val cars = spark.read .format("csv") @@ -1279,4 
+1284,420 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil ) } + + test("SPARK-23846: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(path.getCanonicalPath) + assert(readback.schema == new StructType().add("_c0", IntegerType)) + }) + } + + test("SPARK-23846: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(ds) + + assert(readback.schema == new StructType().add("_c0", IntegerType)) + } + + test("SPARK-23846: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", -1).csv(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", 0).csv(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) + assert(sampled.count() == ds.count()) + } + + test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") { + val litNull: String = null + val df = Seq( + (1, "John Doe"), + (2, ""), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + // Checks for new behavior where an empty string is not coerced to null when `nullValue` is + // set to anything but an empty string literal. + withTempPath { path => + df.write + .option("nullValue", "-") + .csv(path.getAbsolutePath) + val computed = spark.read + .option("nullValue", "-") + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, ""), + (3, litNull), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + // Keeps the old behavior where an empty string is coerced to null when nullValue is not passed.
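Seen from the user API, the sampling behaviour exercised by the SPARK-23846 tests above reduces to one reader option. A hedged sketch, assuming an active SparkSession `spark` (the file path is purely illustrative and not from the patch):

```scala
// Infer the schema from roughly 10% of the input rows instead of a full pass.
// samplingRatio must lie in (0, 1.0]; other values raise IllegalArgumentException,
// as the "out of the range" test above verifies.
val sampled = spark.read
  .option("inferSchema", true)
  .option("samplingRatio", 0.1)
  .csv("/tmp/sample-input.csv")  // hypothetical path, for illustration only

sampled.printSchema()
```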
+ withTempPath { path => + df.write + .csv(path.getAbsolutePath) + val computed = spark.read + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, litNull), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + } + + test("SPARK-24329: skip lines with comments, and one or multiple whitespaces") { + val schema = new StructType().add("colA", StringType) + val ds = spark + .read + .schema(schema) + .option("multiLine", false) + .option("header", true) + .option("comment", "#") + .option("ignoreLeadingWhiteSpace", false) + .option("ignoreTrailingWhiteSpace", false) + .csv(testFile("test-data/comments-whitespaces.csv")) + + checkAnswer(ds, Seq(Row(""" "a" """))) + } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + assert(idf.count() == 2) + checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) + } + } + + def checkHeader(multiLine: Boolean): Unit = { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val exception = intercept[SparkException] { + spark.read + .schema(ischema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + + val shortSchema = new StructType().add("f1", DoubleType) + val exceptionForShortSchema = intercept[SparkException] { + spark.read + .schema(shortSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForShortSchema.getMessage.contains( + "Number of column in CSV header is not equal to number of fields in the schema")) + + val longSchema = new StructType() + .add("f1", DoubleType) + .add("f2", DoubleType) + .add("f3", DoubleType) + + val exceptionForLongSchema = intercept[SparkException] { + spark.read + .schema(longSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForLongSchema.getMessage.contains("Header length: 2, schema size: 3")) + + val caseSensitiveSchema = new StructType().add("F1", DoubleType).add("f2", DoubleType) + val caseSensitiveException = intercept[SparkException] { + spark.read + .schema(caseSensitiveSchema) + .option("multiLine", multiLine) + .option("header", true) + 
.option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(caseSensitiveException.getMessage.contains( + "CSV header does not conform to the schema")) + } + } + } + + test(s"SPARK-23786: Checking column names against schema in the multiline mode") { + checkHeader(multiLine = true) + } + + test(s"SPARK-23786: Checking column names against schema in the per-line mode") { + checkHeader(multiLine = false) + } + + test("SPARK-23786: CSV header must not be checked if it doesn't exist") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", false).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val idf = spark.read + .schema(ischema) + .option("header", false) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + + checkAnswer(idf, odf) + } + } + + test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + val oschema = new StructType().add("A", StringType) + val odf = spark.createDataFrame(List(Row("0")).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("a", StringType) + val idf = spark.read.schema(ischema) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + checkAnswer(idf, odf) + } + } + } + + test("SPARK-23786: check header on parsing of dataset of strings") { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + val exception = intercept[IllegalArgumentException] { + spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds) + } + + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: enforce inferred schema") { + val expectedSchema = new StructType().add("_c0", DoubleType).add("_c1", StringType) + val withHeader = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .csv(Seq("_c0,_c1", "1.0,a").toDS()) + assert(withHeader.schema == expectedSchema) + checkAnswer(withHeader, Seq(Row(1.0, "a"))) + + // Ignore the inferSchema flag if an user sets a schema + val schema = new StructType().add("colA", DoubleType).add("colB", StringType) + val ds = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("colA,colB", "1.0,a").toDS()) + assert(ds.schema == schema) + checkAnswer(ds, Seq(Row(1.0, "a"))) + + val exception = intercept[IllegalArgumentException] { + spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("col1,col2", "1.0,a").toDS()) + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: warning should be printed if CSV header doesn't conform to schema") { + class TestAppender extends AppenderSkeleton { + var events = new java.util.ArrayList[LoggingEvent] + override def close(): Unit = {} + override def requiresLayout: Boolean = false + protected def append(event: LoggingEvent): Unit = events.add(event) + } + + val testAppender1 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender1) + 
try { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + + spark.read.schema(ischema).option("header", true).option("enforceSchema", true).csv(ds) + } finally { + LogManager.getRootLogger.removeAppender(testAppender1) + } + assert(testAppender1.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + + val testAppender2 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender2) + try { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + spark.read + .schema(ischema) + .option("header", true) + .option("enforceSchema", true) + .csv(path.getCanonicalPath) + .collect() + } + } finally { + LogManager.getRootLogger.removeAppender(testAppender2) + } + assert(testAppender2.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + } + + test("SPARK-25134: check header on parsing of dataset with projection and column pruning") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { + Seq(false, true).foreach { multiLine => + withTempPath { path => + val dir = path.getAbsolutePath + Seq(("a", "b")).toDF("columnA", "columnB").write + .format("csv") + .option("header", true) + .save(dir) + + // schema with one column + checkAnswer(spark.read + .format("csv") + .option("header", true) + .option("enforceSchema", false) + .option("multiLine", multiLine) + .load(dir) + .select("columnA"), + Row("a")) + + // empty schema + assert(spark.read + .format("csv") + .option("header", true) + .option("enforceSchema", false) + .option("multiLine", multiLine) + .load(dir) + .count() === 1L) + } + } + } + } + + test("SPARK-24645 skip parsing when columnPruning enabled and partitions scanned only") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id").write.partitionBy("p").csv(dir) + checkAnswer(spark.read.csv(dir).selectExpr("sum(p)"), Row(5)) + } + } + } + + test("SPARK-24676 project required data from parsed data when columnPruning disabled") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id AS c0", "id AS c1").write.partitionBy("p") + .option("header", "true").csv(dir) + val df1 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)", "count(c0)") + checkAnswer(df1, Row(5, 10)) + + // empty required column case + val df2 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)") + checkAnswer(df2, Row(5)) + } + + // the case where tokens length != parsedSchema length + withTempPath { path => + val dir = path.getAbsolutePath + Seq("1,2").toDF().write.text(dir) + // more tokens + val df1 = spark.read.schema("c0 int").format("csv").option("mode", "permissive").load(dir) + checkAnswer(df1, Row(1)) + // less tokens + val df2 = spark.read.schema("c0 int, c1 int, c2 int").format("csv") + .option("mode", "permissive").load(dir) + checkAnswer(df2, Row(1, 2, null)) + } + } + } + + test("count() for malformed input") { + def countForMalformedCSV(expected: Long, input: Seq[String]): 
Unit = { + val schema = new StructType().add("a", IntegerType) + val strings = spark.createDataset(input) + val df = spark.read.schema(schema).option("header", false).csv(strings) + + assert(df.count() == expected) + } + def checkCount(expected: Long): Unit = { + val validRec = "1" + val inputs = Seq( + Seq("{-}", validRec), + Seq(validRec, "?"), + Seq("0xAC", validRec), + Seq(validRec, "0.314"), + Seq("\\\\\\", validRec) + ) + inputs.foreach { input => + countForMalformedCSV(expected, input) + } + } + + checkCount(2) + countForMalformedCSV(0, Seq("")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala new file mode 100644 index 0000000000000..3e20cc47dca2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} + +private[csv] trait TestCsvData { + protected def spark: SparkSession + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + index.toString + } else { + (index.toDouble + 0.1).toString + } + }(Encoders.STRING) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index efbf73534bd19..458edb253fb33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParserSuite extends SparkFunSuite { - private val parser = - new UnivocityParser(StructType(Seq.empty), new CSVOptions(Map.empty[String, String], "GMT")) + private val parser = new UnivocityParser( + StructType(Seq.empty), + new CSVOptions(Map.empty[String, String], false, "GMT")) private def assertNull(v: Any) = assert(v == null) @@ -38,7 +39,7 @@ class UnivocityParserSuite extends SparkFunSuite { stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => val decimalValue = new BigDecimal(decimalVal.toString) - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", decimalType, 
options = options).apply(strVal) === Decimal(decimalValue, decimalType.precision, decimalType.scale)) } @@ -51,21 +52,21 @@ class UnivocityParserSuite extends SparkFunSuite { // Nullable field with nullValue option. types.foreach { t => // Tests that a custom nullValue. - val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) assertNull(converter.apply("-")) assertNull(converter.apply(null)) // Tests that the default nullValue is empty string. - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) } // Not nullable field with nullValue option. types.foreach { t => // Casts a null to not nullable field should throw an exception. - val options = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = false, options = options) var message = intercept[RuntimeException] { @@ -81,7 +82,7 @@ class UnivocityParserSuite extends SparkFunSuite { // If nullValue is different with empty string, then, empty string should not be casted into // null. Seq(true, false).foreach { b => - val options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") val converter = parser.makeConverter("_1", StringType, nullable = b, options = options) assert(converter.apply("") == UTF8String.fromString("")) @@ -89,7 +90,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for empty string with non null type") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val exception = intercept[RuntimeException]{ parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") } @@ -97,7 +98,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Types are cast correctly") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) @@ -107,7 +108,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) val timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), "GMT") + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") val customTimestamp = "31/01/2015 00:00" val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime val castedTimestamp = @@ -116,7 +117,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(castedTimestamp == expectedTime * 1000L) val customDate = "31/01/2015" - val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), "GMT") + val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") val expectedDate = dateOptions.dateFormat.parse(customDate).getTime val castedDate = parser.makeConverter("_1", DateType, nullable 
= true, options = dateOptions) @@ -131,7 +132,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for casting an invalid string to Float and Double Types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val types = Seq(DoubleType, FloatType) val input = Seq("10u000", "abc", "1 2/3") types.foreach { dt => @@ -145,7 +146,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "nn"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") val floatVal: Float = parser.makeConverter( "_1", FloatType, nullable = true, options = options ).apply("nn").asInstanceOf[Float] @@ -156,7 +157,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") val doubleVal: Double = parser.makeConverter( "_1", DoubleType, nullable = true, options = options ).apply("-").asInstanceOf[Double] @@ -165,14 +166,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val floatVal1 = parser.makeConverter( "_1", FloatType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Float] assert(floatVal1 == Float.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val floatVal2 = parser.makeConverter( "_1", FloatType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Float] @@ -181,14 +182,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val doubleVal1 = parser.makeConverter( "_1", DoubleType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Double] assert(doubleVal1 == Double.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val doubleVal2 = parser.makeConverter( "_1", DoubleType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Double] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala new file mode 100644 index 0000000000000..a2b747eaab411 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.json + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + +/** + * The benchmark aims to measure the performance of JSON parsing when an encoding is set and when it isn't. + * To run this: + * spark-submit --class --jars + */ +object JSONBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-json-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + + def schemaInferring(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON schema inferring", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + + benchmark.addCase("No encoding", 3) { _ => + spark.read.json(path.getAbsolutePath) + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 38902 / 39282 2.6 389.0 1.0X + UTF-8 is set 56959 / 57261 1.8 569.6 0.7X + */ + benchmark.run() + } + } + + def perlineParsing(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON per-line parsing", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write.json(path.getAbsolutePath) + val schema = new StructType().add("fieldA", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 25947 / 26188 3.9 259.5 1.0X + UTF-8 is set 46319 / 46417 2.2 463.2 0.6X + */ + benchmark.run() + } + } + + def perlineParsingOfWideColumn(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + +
spark.sparkContext.range(0, rowsNum, 1) + .map { i => + val s = "abcdef0123456789ABCDEF" * 20 + s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" + } + .toDF().write.text(path.getAbsolutePath) + val schema = new StructType() + .add("a", StringType).add("b", LongType) + .add("c", StringType).add("d", LongType) + .add("e", StringType).add("f", LongType) + .add("x", StringType).add("y", LongType) + .add("z", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 45543 / 45660 0.2 4554.3 1.0X + UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X + */ + benchmark.run() + } + } + + def countBenchmark(rowsNum: Int): Unit = { + val colsNum = 10 + val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write + .json(path.getAbsolutePath) + + val ds = spark.read.schema(schema).json(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ => + ds.select("*").filter((_: Row) => true).count() + } + benchmark.addCase(s"Select 1 column + count()", 3) { _ => + ds.select($"col1").filter((_: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz + + Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + Select 10 columns + count() 9961 / 10006 1.0 996.1 1.0X + Select 1 column + count() 8355 / 8470 1.2 835.5 1.2X + count() 2104 / 2156 4.8 210.4 4.7X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + schemaInferring(100 * 1000 * 1000) + perlineParsing(100 * 1000 * 1000) + perlineParsingOfWideColumn(10 * 1000 * 1000) + countBenchmark(10 * 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 70aee561ff0f6..3e4cc8f166279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, StringWriter} -import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths, StandardOpenOption} +import java.io._ +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -31,11 +31,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} -import 
org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -2128,38 +2128,391 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } - test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) + test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + + assert(readback.schema == new StructType().add("f1", LongType)) + }) + } + + test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read.option("samplingRatio", 0.1).json(ds) + + assert(readback.schema == new StructType().add("f1", LongType)) + } + + test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", -1).json(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", 0).json(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("samplingRatio", 1.0).json(ds) + assert(sampled.count() == ds.count()) + } + + test("SPARK-23723: json in UTF-16 with BOM") { + val fileName = "test-data/utf16WithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("encoding", "UTF-16") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"), Row("Doug", "Rood"))) + } + + test("SPARK-23723: multi-line json in UTF-32BE with BOM") { + val fileName = "test-data/utf32BEWithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + 
.option("multiline", "true") + .options(Map("encoding" -> "UTF-16LE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Unsupported encoding name") { + val invalidCharset = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + spark.read + .options(Map("encoding" -> invalidCharset, "lineSep" -> "\n")) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains(invalidCharset)) + } + + test("SPARK-23723: checking that the encoding option is case agnostic") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "uTf-16lE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: specified encoding is not matched to actual encoding") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val exception = intercept[SparkException] { + spark.read.schema(schema) + .option("mode", "FAILFAST") + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16BE")) + .json(testFile(fileName)) + .count() + } + val errMsg = exception.getMessage + + assert(errMsg.contains("Malformed records are detected in record parsing")) + } + + def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, + expectedContent: String): Unit = { + val jsonFiles = new File(pathToJsonFiles) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("json")) + val actualContent = jsonFiles.map { file => + new String(Files.readAllBytes(file.toPath), expectedEncoding) + }.mkString.trim + + assert(actualContent == expectedContent) + } + + test("SPARK-23723: save json in UTF-32BE") { + val encoding = "UTF-32BE" withTempPath { path => - val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), - StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) - for (i <- 0 until 100) { - if (predefinedSample.contains(i)) { - writer.write(s"""{"f1":${i.toString}}""" + "\n") + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .options(Map("encoding" -> encoding)) + .json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = encoding, + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: save json in default encoding - UTF-8") { + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write.json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = "UTF-8", + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: wrong output encoding") { + val encoding = "UTF-128" + val exception = intercept[SparkException] { + withTempPath { path => + val df = spark.createDataset(Seq((0))) + df.write + .options(Map("encoding" -> encoding)) + .json(path.getCanonicalPath) + } + } + + val baos = new ByteArrayOutputStream() + val ps = new PrintStream(baos, true, "UTF-8") + exception.printStackTrace(ps) + ps.flush() + + assert(baos.toString.contains( + "java.nio.charset.UnsupportedCharsetException: UTF-128")) + } + + test("SPARK-23723: read back json in UTF-16LE") { + val options = Map("encoding" -> "UTF-16LE", "lineSep" -> "\n") + withTempPath { path => + val ds = spark.createDataset(Seq(("a", 1), ("b", 2), ("c", 3))).repartition(2) + 
ds.write.options(options).json(path.getCanonicalPath) + + val readBack = spark + .read + .options(options) + .json(path.getCanonicalPath) + + checkAnswer(readBack.toDF(), ds.toDF()) + } + } + + test("SPARK-23723: write json in UTF-16/32 with multiline off") { + Seq("UTF-16", "UTF-32").foreach { encoding => + withTempPath { path => + val ds = spark.createDataset(Seq(("a", 1))).repartition(1) + ds.write + .option("encoding", encoding) + .option("multiline", false) + .json(path.getCanonicalPath) + val jsonFiles = path.listFiles().filter(_.getName.endsWith("json")) + jsonFiles.foreach { jsonFile => + val readback = Files.readAllBytes(jsonFile.toPath) + val expected = ("""{"_1":"a","_2":1}""" + "\n").getBytes(Charset.forName(encoding)) + assert(readback === expected) + } + } + } + } + + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { + val schema = new StructType().add("f1", StringType).add("f2", IntegerType) + withTempPath { path => + val records = List(("a", 1), ("b", 2)) + val data = records + .map(rec => s"""{"f1":"${rec._1}", "f2":${rec._2}}""".getBytes(encoding)) + .reduce((a1, a2) => a1 ++ lineSep.getBytes(encoding) ++ a2) + val os = new FileOutputStream(path) + os.write(data) + os.close() + val reader = if (inferSchema) { + spark.read } else { - writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") + spark.read.schema(schema) } + val readBack = reader + .option("encoding", encoding) + .option("lineSep", lineSep) + .json(path.getCanonicalPath) + checkAnswer(readBack, records.map(rec => Row(rec._1, rec._2))) } - writer.close() + } + } - val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) - assert(ds.schema == new StructType().add("f1", LongType)) + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, "::", "ISO-8859-1", true), + (3, "!!!@3", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "куку", "CP1251", true), + (7, "sep", "utf-8", false), + (8, "\r\n", "UTF-16LE", false), + (9, "\r\n", "utf-16be", true), + (10, "\u000d\u000a", "UTF-32BE", false), + (11, "\u000a\u000d", "UTF-8", true), + (12, "===", "US-ASCII", false), + (13, "$^+", "utf-32le", true) + ).foreach { + case (testNum, sep, encoding, inferSchema) => checkReadJson(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("SPARK-23724: lineSep should be set if encoding if different from UTF-8") { + val encoding = "UTF-16LE" + val exception = intercept[IllegalArgumentException] { + spark.read + .options(Map("encoding" -> encoding)) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains( + s"""The lineSep option must be specified for the $encoding encoding""")) + } + + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is enabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson + """{"a":1}""").toDS().write.text(path) + val expected = s"""${badJson}{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", true) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Row(null, expected)) } } - test("SPARK-23849: usage of samplingRation while 
parsing of dataset of strings") { - val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - if (predefinedSample.contains(i)) { - s"""{"f1":${i.toString}}""" + "\n" - } else { - s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is disabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", false) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("SPARK-23094: permissively parse a dataset contains JSON with leading nulls") { + checkAnswer( + spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), + Row(badJson)) + } + + test("SPARK-23772 ignore column of all null values or empty array during schema inference") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + + // primitive types + Seq( + """{"a":null, "b":1, "c":3.0}""", + """{"a":null, "b":null, "c":"string"}""", + """{"a":null, "b":null, "c":null}""") + .toDS().write.text(path) + var df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + var expectedSchema = new StructType() + .add("b", LongType).add("c", StringType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil) + + // arrays + Seq( + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""", + """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""", + """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", ArrayType(LongType)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil) + + // structs + Seq( + """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""", + """{"a":null, "b":null}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) + :: Nil)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) + } + } + + test("SPARK-24190: restrictions for JSONOptions in read") { + for (encoding <- Set("UTF-16", "UTF-32")) { + val exception = intercept[IllegalArgumentException] { + spark.read + .option("encoding", encoding) + .option("multiLine", false) + .json(testFile("test-data/utf16LE.json")) + .count() } - }.toDS() - val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + assert(exception.getMessage.contains("encoding must not be included in the blacklist")) + } + } + + test("count() for malformed input") { + def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = { + val schema = new StructType().add("a", StringType) + val strings = spark.createDataset(input) + val df = 
spark.read.schema(schema).json(strings) + + assert(df.count() == expected) + } + def checkCount(expected: Long): Unit = { + val validRec = """{"a":"b"}""" + val inputs = Seq( + Seq("{-}", validRec), + Seq(validRec, "?"), + Seq("}", validRec), + Seq(validRec, """{"a": [1, 2, 3]}"""), + Seq("""{"a": {"a": "b"}}""", validRec) + ) + inputs.foreach { input => + countForMalformedJSON(expected, input) + } + } - assert(ds.schema == new StructType().add("f1", LongType)) + checkCount(2) + countForMalformedJSON(0, Seq("")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 13084ba4a7f04..6e9559edf8ec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -233,4 +233,16 @@ private[json] trait TestJsonData { spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + s"""{"f1":${index.toString}}""" + } else { + s"""{"f1":${(index.toDouble + 0.1).toString}}""" + } + }(Encoders.STRING) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index f58c331f33ca8..e9dccbf2e261c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -562,20 +562,57 @@ abstract class OrcQueryTest extends OrcTest { } } + def testAllCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.json(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.json(new Path(basePath, "second").toString) + val df = spark.read.orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString) + assert(df.count() == 0) + } + } + + def testAllCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.json(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.json(new Path(basePath, "second").toString) + val df = spark.read.schema("a long").orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString) + assert(df.count() == 0) + } + } + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { testIgnoreCorruptFiles() testIgnoreCorruptFilesWithoutSchemaInfer() + val m1 = intercept[AnalysisException] { + testAllCorruptFiles() + }.getMessage + assert(m1.contains("Unable to infer schema for ORC")) + testAllCorruptFilesWithoutSchemaInfer() } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { val m1 = intercept[SparkException] { testIgnoreCorruptFiles() }.getMessage - assert(m1.contains("Could not read footer for file")) + assert(m1.contains("Malformed ORC file")) val m2 = intercept[SparkException] { testIgnoreCorruptFilesWithoutSchemaInfer() }.getMessage assert(m2.contains("Malformed ORC file")) + val m3 = 
intercept[SparkException] { + testAllCorruptFiles() + }.getMessage + assert(m3.contains("Could not read footer for file")) + val m4 = intercept[SparkException] { + testAllCorruptFilesWithoutSchemaInfer() + }.getMessage + assert(m4.contains("Malformed ORC file")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 8a3bbd03a26dc..02bfb7197ffc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import java.sql.Timestamp import java.util.Locale import org.apache.orc.OrcConf.COMPRESS @@ -169,6 +170,14 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-24322 Fix incorrect workaround for bug in java.sql.Timestamp") { + withTempPath { path => + val ts = Timestamp.valueOf("1900-05-05 12:34:56.000789") + Seq(ts).toDF.write.orc(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), Row(ts)) + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index f3ecc5ced689f..4b2437803d645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -91,9 +91,14 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils summary: Boolean, check: Boolean): Option[FileStatus] = { var result: Option[FileStatus] = None + val summaryLevel = if (summary) { + "ALL" + } else { + "NONE" + } withSQLConf( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> summaryLevel) { withTempPath { dest => val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) val destPath = new Path(dest.toURI) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala index ed8fd2b453456..09de715e87a11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.test.SharedSQLContext class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { test("Test `spark.sql.parquet.compression.codec` config") { - Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO", "LZ4", "BROTLI", "ZSTD").foreach { c => withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { val expected = if (c == "NONE") "UNCOMPRESSED" else c val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index 3a0867fd2b78b..94abf115cef35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo } testReadFooters(true) - val exception = intercept[java.io.IOException] { + val exception = intercept[SparkException] { testReadFooters(false) - } + }.getCause assert(exception.getMessage().contains("Could not read footer for file")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1d3476e747046..be4f498c921ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets -import java.sql.Date +import java.sql.{Date, Timestamp} -import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} @@ -31,6 +32,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} @@ -55,6 +57,11 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + private lazy val parquetFilters = + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold) + override def beforeEach(): Unit = { super.beforeEach() // Note that there are many tests here that require record-level filtering set to be true. 
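The new `parquetFilters` helper above is built from the suite's pushdown-related settings (date, timestamp, decimal, string startsWith, and the IN-filter threshold), and the next hunk switches the matching flags on for the predicate checks in this suite. As a rough sketch of what those flags mean outside the test harness, the standalone snippet below enables them and writes and reads a small Parquet table; the object name, local-mode session, and temp path are illustrative assumptions, while the SQLConf constants are the ones referenced in this suite.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object ParquetPushdownFlagsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("parquet-pushdown-flags-sketch")
      .getOrCreate()

    // Enable the base pushdown flag plus the per-type flags this suite exercises.
    Seq(
      SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED,
      SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED,
      SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED,
      SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED,
      SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED
    ).foreach(flag => spark.conf.set(flag.key, "true"))

    val path = java.nio.file.Files.createTempDirectory("pushdown-sketch").toString
    spark.range(100).selectExpr("id", "cast(id as string) AS s")
      .write.mode("overwrite").parquet(path)

    // With the flags on, predicates like these become candidates for row-group filtering;
    // the physical plan's PushedFilters entry shows what was handed to Parquet.
    val df = spark.read.parquet(path).where("id < 10 AND s LIKE '1%'")
    df.explain()
    df.show()

    spark.stop()
  }
}

For the per-type flags, a disabled flag just means ParquetFilters generates no Parquet predicate for that type, so the query still runs and the filtering happens row by row in Spark.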
@@ -80,6 +87,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) @@ -99,7 +109,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -138,6 +149,71 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { + assert(data.size === 4) + val ts1 = data.head + val ts2 = data(1) + val ts3 = data(2) + val ts4 = data(3) + + withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i))) + + checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + + checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1) + checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1) + checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4) + + checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1) + checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4) + checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1) + checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4) + + checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4) + checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or], Seq(Row(ts1), Row(ts4))) + } + } + + private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit = { + withTempPath { file => + data.write.parquet(file.getCanonicalPath) + readParquetFile(file.toString)(f) + } + } + + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. 
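  // For reference: `canDrop` is the hook Parquet calls with a row group's column statistics
  // (min/max); returning true lets the whole row group be skipped without reading it.
  // `inverseCanDrop` is the analogous hook used when the predicate is negated, which is what
  // the `not like '10%'` case further down ("Test inverseCanDrop() has taken effect") relies on.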
+ private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + dataFrame.write.option("parquet.block.size", 512).parquet(path) + Seq(true, false).foreach { pushDown => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> pushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) + + val df = spark.read.parquet(path).filter(filter) + df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) + if (pushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } + + AccumulatorContext.remove(accu.id) + } + } + } + } + test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -149,6 +225,62 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - tinyint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => + assert(df.schema.head.dataType === ByteType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - smallint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + assert(df.schema.head.dataType === ShortType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4) + 
+ checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -357,6 +489,117 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - timestamp") { + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS + val millisData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123"), + Timestamp.valueOf("2018-06-15 08:28:53.123"), + Timestamp.valueOf("2018-06-16 08:28:53.123"), + Timestamp.valueOf("2018-06-17 08:28:53.123")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) { + testTimestampPushdown(millisData) + } + + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS + val microsData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123456"), + Timestamp.valueOf("2018-06-15 08:28:53.123456"), + Timestamp.valueOf("2018-06-16 08:28:53.123456"), + Timestamp.valueOf("2018-06-17 08:28:53.123456")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) { + testTimestampPushdown(microsData) + } + + // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.INT96.toString) { + withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df => + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.IsNull("_1")) + } + } + } + } + + test("filter pushdown - decimal") { + Seq(true, false).foreach { legacyFormat => + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) { + Seq( + s"a decimal(${Decimal.MAX_INT_DIGITS}, 2)", // 32BitDecimalType + s"a decimal(${Decimal.MAX_LONG_DIGITS}, 2)", // 64BitDecimalType + "a decimal(38, 18)" // ByteArrayDecimalType + ).foreach { schemaDDL => + val schema = StructType.fromDDL(schemaDDL) + val rdd = + spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) + val dataFrame = spark.createDataFrame(rdd, schema) + testDecimalPushDown(dataFrame) { implicit df => + assert(df.schema === schema) + checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('a === 1, classOf[Eq[_]], 1) + checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('a < 2, classOf[Lt[_]], 1) + checkFilterPredicate('a > 3, classOf[Gt[_]], 4) + checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + } + } + } + + test("Ensure that filter value 
matched the parquet file schema") { + val scale = 2 + val schema = StructType(Seq( + StructField("cint", IntegerType), + StructField("cdecimal1", DecimalType(Decimal.MAX_INT_DIGITS, scale)), + StructField("cdecimal2", DecimalType(Decimal.MAX_LONG_DIGITS, scale)), + StructField("cdecimal3", DecimalType(DecimalType.MAX_PRECISION, scale)) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + val decimal = new JBigDecimal(10).setScale(scale) + val decimal1 = new JBigDecimal(10).setScale(scale + 1) + assert(decimal.scale() === scale) + assert(decimal1.scale() === scale + 1) + + assertResult(Some(lt(intColumn("cdecimal1"), 1000: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal1)) + } + + assertResult(Some(lt(longColumn("cdecimal2"), 1000L: java.lang.Long))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal1)) + } + + assert(parquetFilters.createFilter( + parquetSchema, sources.LessThan("cdecimal3", decimal)).isDefined) + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal3", decimal1)) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ @@ -513,28 +756,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex StructField("c", DoubleType, nullable = true) )) + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + assertResult(Some(and( lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { - ParquetFilters.createFilter( - schema, + parquetFilters.createFilter( + parquetSchema, sources.And( sources.LessThan("a", 10), sources.GreaterThan("c", 1.5D))) } assertResult(None) { - ParquetFilters.createFilter( - schema, + parquetFilters.createFilter( + parquetSchema, sources.And( sources.LessThan("a", 10), sources.StringContains("b", "prefix"))) } assertResult(None) { - ParquetFilters.createFilter( - schema, + parquetFilters.createFilter( + parquetSchema, sources.Not( sources.And( sources.GreaterThan("a", 1), @@ -572,7 +817,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) - df.collect if (enablePushDown) { assert(accu.value == 0) @@ -587,21 +831,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-17213: Broken Parquet filter push-down for string columns") { - withTempPath { dir => - import testImplicits._ + Seq(true, false).foreach { vectorizedEnabled => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedEnabled.toString) { + withTempPath { dir => + import testImplicits._ - val path = dir.getCanonicalPath - // scalastyle:off nonascii - Seq("a", "é").toDF("name").write.parquet(path) - // scalastyle:on nonascii + val path = dir.getCanonicalPath + // scalastyle:off nonascii + Seq("a", "é").toDF("name").write.parquet(path) + // scalastyle:on nonascii - assert(spark.read.parquet(path).where("name > 'a'").count() == 1) - assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) + assert(spark.read.parquet(path).where("name > 'a'").count() == 1) + 
assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) - // scalastyle:off nonascii - assert(spark.read.parquet(path).where("name < 'é'").count() == 1) - assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) - // scalastyle:on nonascii + // scalastyle:off nonascii + assert(spark.read.parquet(path).where("name < 'é'").count() == 1) + assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) + // scalastyle:on nonascii + } + } } } @@ -646,6 +894,133 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-23852: Broken Parquet push-down for partially-written stats") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. + // The row-group statistics include null counts, but not min and max values, which + // triggers PARQUET-1217. + val df = readResourceParquetFile("test-data/parquet-1217.parquet") + + // Will return 0 rows if PARQUET-1217 is not fixed. + assert(df.where("col > 0").count() === 2) + } + } + + test("filter pushdown - StringStartsWith") { + withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => + checkFilterPredicate( + '_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "2str2") + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq.empty[Row]) + } + + checkFilterPredicate( + !'_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq().map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "3str3", "4str4").map(Row(_))) + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + } + + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), + sources.StringStartsWith("_1", null)) + } + } + + import testImplicits._ + // Test canDrop() has taken effect + testStringStartsWith(spark.range(1024).map(_.toString).toDF(), "value like 'a%'") + // Test inverseCanDrop() has taken effect + testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'") + } + + test("SPARK-17091: Convert IN predicate to Parquet filter push-down") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null))) + } + + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10))) + } + + // Remove duplicates + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10))) + } + + assertResult(Some(or(or( 
+ FilterApi.eq(intColumn("a"), 10: Integer), + FilterApi.eq(intColumn("a"), 20: Integer)), + FilterApi.eq(intColumn("a"), 30: Integer))) + ) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30))) + } + + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined) + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty) + + import testImplicits._ + withTempPath { path => + val data = 0 to 1024 + data.toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert 1024 to null + .coalesce(1).write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + Seq(true, false).foreach { pushEnabled => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> pushEnabled.toString) { + Seq(1, 5, 10, 11).foreach { count => + val filter = s"a in(${Range(0, count).mkString(",")})" + assert(df.where(filter).count() === count) + val actual = stripSparkFilter(df.where(filter)).collect().length + if (pushEnabled && count <= conf.parquetFilterPushDownInFilterThreshold) { + assert(actual > 1 && actual < data.length) + } else { + assert(actual === data.length) + } + } + assert(df.where("a in(null)").count() === 0) + assert(df.where("a = null").count() === 0) + assert(df.where("a is null").count() === 1) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0b3e8ca060d87..002c42f23bd64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -543,7 +543,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) - withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") { + withSQLConf(ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part-r-0.parquet" spark.range(1 << 16).selectExpr("(id % 4) AS i") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 9c75965639d8a..f06e1867151e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import scala.language.existentials + import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER @@ -175,8 +177,9 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS val oneFooter = ParquetFileReader.readFooter(hadoopConf, part.getPath, NO_FILTER) assert(oneFooter.getFileMetaData.getSchema.getColumns.size === 1) - 
assert(oneFooter.getFileMetaData.getSchema.getColumns.get(0).getType() === - PrimitiveTypeName.INT96) + val typeName = oneFooter + .getFileMetaData.getSchema.getColumns.get(0).getPrimitiveType.getPrimitiveTypeName + assert(typeName === PrimitiveTypeName.INT96) val oneBlockMeta = oneFooter.getBlocks().get(0) val oneBlockColumnMeta = oneBlockMeta.getColumns().get(0) val columnStats = oneBlockColumnMeta.getStatistics diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index e887c9734a8b8..9966ed94a8392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1014,7 +1014,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val path = dir.getCanonicalPath withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", "spark.sql.sources.commitProtocolClass" -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { spark.range(3).write.parquet(s"$path/p0=0/p1=0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e1f094d0a7af3..54c77dddc3525 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -108,7 +108,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -117,7 +117,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -126,7 +126,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) @@ -275,7 +275,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName, SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true", - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true" + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL" ) { testSchemaMerging(2) } @@ -879,6 +879,18 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("SPARK-24230: 
filter row group using dictionary") { + withSQLConf(("parquet.filter.dictionary.enabled", "true")) { + // create a table with values from 0, 2, ..., 18 that will be dictionary-encoded + withParquetTable((0 until 100).map(i => ((i * 2) % 20, s"data-$i")), "t") { + // search for a key that is not present so the dictionary filter eliminates all row groups + // Fails without SPARK-24230: + // java.io.IOException: expecting more rows but reached last block. Read 0 out of 50 + checkAnswer(sql("SELECT _2 FROM t WHERE t._1 = 5"), Seq.empty) + } + } + } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala deleted file mode 100644 index e43336d947364..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.parquet - -import java.io.File - -import scala.collection.JavaConverters._ -import scala.util.Try - -import org.apache.spark.SparkConf -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - -/** - * Benchmark to measure parquet read performance. - * To run this: - * spark-submit --class --jars - */ -object ParquetReadBenchmark { - val conf = new SparkConf() - conf.set("spark.sql.parquet.compression.codec", "snappy") - - val spark = SparkSession.builder - .master("local[1]") - .appName("test-sql-context") - .config(conf) - .getOrCreate() - - // Set default configs. Individual cases will change them if necessary. - spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - def intScanBenchmark(values: Int): Unit = { - // Benchmarks running through spark sql. 
- val sqlBenchmark = new Benchmark("SQL Single Int Column Scan", values) - // Benchmarks driving reader component directly. - val parquetReaderBenchmark = new Benchmark("Parquet Reader Single Int Column Scan", values) - - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as id from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(id) from tempTable").collect() - } - - sqlBenchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(id) from tempTable").collect() - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - // Driving the parquet reader in batch mode directly. - parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - val col = batch.column(0) - while (reader.nextBatch()) { - val numRows = batch.numRows() - var i = 0 - while (i < numRows) { - if (!col.isNullAt(i)) sum += col.getInt(i) - i += 1 - } - } - } finally { - reader.close() - } - } - } - - // Decoding in vectorized but having the reader return rows. - parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val it = batch.rowIterator() - while (it.hasNext) { - val record = it.next() - if (!record.isNullAt(0)) sum += record.getInt(0) - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 215 / 262 73.0 13.7 1.0X - SQL Parquet MR 1946 / 2083 8.1 123.7 0.1X - */ - sqlBenchmark.run() - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - ParquetReader Vectorized 123 / 152 127.8 7.8 1.0X - ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 0.7X - */ - parquetReaderBenchmark.run() - } - } - } - - def intStringScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Int and String Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select 
sum(c1), sum(length(c2)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 628 / 720 16.7 59.9 1.0X - SQL Parquet MR 1905 / 2239 5.5 181.7 0.3X - */ - benchmark.run() - } - } - } - - def stringDictionaryScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String Dictionary", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c1)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(length(c1)) from tempTable").collect - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 329 / 337 31.9 31.4 1.0X - SQL Parquet MR 1131 / 1325 9.3 107.8 0.3X - */ - benchmark.run() - } - } - } - - def partitionTableScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select id % 2 as p, cast(id as INT) as id from t1") - .write.partitionBy("p").parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Partitioned Table", values) - - benchmark.addCase("Read data column") { iter => - spark.sql("select sum(id) from tempTable").collect - } - - benchmark.addCase("Read partition column") { iter => - spark.sql("select sum(p) from tempTable").collect - } - - benchmark.addCase("Read both columns") { iter => - spark.sql("select sum(p), sum(id) from tempTable").collect - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Read data column 191 / 250 82.1 12.2 1.0X - Read partition column 82 / 86 192.4 5.2 2.3X - Read both columns 220 / 248 71.5 14.0 0.9X - */ - benchmark.run() - } - } - } - - def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + - s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - 
spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String with Nulls Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c2)) from tempTable where c1 is " + - "not NULL and c2 is not NULL").collect() - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - benchmark.addCase("PR Vectorized") { num => - var sum = 0 - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val rowIterator = batch.rowIterator() - while (rowIterator.hasNext) { - val row = rowIterator.next() - val value = row.getUTF8String(0) - if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 1229 / 1648 8.5 117.2 1.0X - PR Vectorized 833 / 846 12.6 79.4 1.5X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (50%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 995 / 1053 10.5 94.9 1.0X - PR Vectorized 732 / 772 14.3 69.8 1.4X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 326 / 333 32.2 31.1 1.0X - PR Vectorized 190 / 200 55.1 18.2 1.7X - */ - - benchmark.run() - } - } - } - - def main(args: Array[String]): Unit = { - intScanBenchmark(1024 * 1024 * 15) - intStringScanBenchmark(1024 * 1024 * 10) - stringDictionaryScanBenchmark(1024 * 1024 * 10) - partitionTableScanBenchmark(1024 * 1024 * 15) - for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { - stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala new file mode 100644 index 0000000000000..eb99654fa78f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.scalactic.Equality + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class ParquetSchemaPruningSuite + extends QueryTest + with ParquetTest + with SchemaPruningTest + with SharedSQLContext { + case class FullName(first: String, middle: String, last: String) + case class Contact( + id: Int, + name: FullName, + address: String, + pets: Int, + friends: Array[FullName] = Array.empty, + relatives: Map[String, FullName] = Map.empty) + + val janeDoe = FullName("Jane", "X.", "Doe") + val johnDoe = FullName("John", "Y.", "Doe") + val susanSmith = FullName("Susan", "Z.", "Smith") + + private val contacts = + Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith), + relatives = Map("brother" -> johnDoe)) :: + Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe)) :: Nil + + case class Name(first: String, last: String) + case class BriefContact(id: Int, name: Name, address: String) + + private val briefContacts = + BriefContact(2, Name("Janet", "Jones"), "567 Maple Drive") :: + BriefContact(3, Name("Jim", "Jones"), "6242 Ash Street") :: Nil + + case class ContactWithDataPartitionColumn( + id: Int, + name: FullName, + address: String, + pets: Int, + friends: Array[FullName] = Array(), + relatives: Map[String, FullName] = Map(), + p: Int) + + case class BriefContactWithDataPartitionColumn(id: Int, name: Name, address: String, p: Int) + + private val contactsWithDataPartitionColumn = + contacts.map { case Contact(id, name, address, pets, friends, relatives) => + ContactWithDataPartitionColumn(id, name, address, pets, friends, relatives, 1) } + private val briefContactsWithDataPartitionColumn = + briefContacts.map { case BriefContact(id, name, address) => + BriefContactWithDataPartitionColumn(id, name, address, 2) } + + testSchemaPruning("select a single complex field") { + val query = sql("select name.middle from contacts") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), Row("X.") :: Row("Y.") :: Row(null) :: Row(null) :: Nil) + } + + testSchemaPruning("select a single complex field and its parent struct") { + val query = sql("select name.middle, name from contacts") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("X.", Row("Jane", "X.", "Doe")) :: + Row("Y.", Row("John", "Y.", "Doe")) :: + Row(null, Row("Janet", null, "Jones")) :: + Row(null, Row("Jim", null, "Jones")) :: + Nil) + } + + testSchemaPruning("select a single complex field array and its parent struct array") { + val query = sql("select friends.middle, friends from contacts where p=1") + checkScan(query, + "struct>>") + checkAnswer(query.orderBy("id"), + Row(Array("Z."), Array(Row("Susan", "Z.", "Smith"))) :: + Row(Array.empty[String], Array.empty[Row]) :: + Nil) + } + + testSchemaPruning("select a single complex field from a map entry and its parent map entry") { + val query = + sql("select relatives[\"brother\"].middle, relatives[\"brother\"] from contacts where p=1") + checkScan(query, + "struct>>") + 
checkAnswer(query.orderBy("id"), + Row("Y.", Row("John", "Y.", "Doe")) :: + Row(null, null) :: + Nil) + } + + testSchemaPruning("select a single complex field and the partition column") { + val query = sql("select name.middle, p from contacts") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("X.", 1) :: Row("Y.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil) + } + + ignore("partial schema intersection - select missing subfield") { + val query = sql("select name.middle, address from contacts where p=2") + checkScan(query, "struct,address:string>") + checkAnswer(query.orderBy("id"), + Row(null, "567 Maple Drive") :: + Row(null, "6242 Ash Street") :: Nil) + } + + testSchemaPruning("no unnecessary schema pruning") { + val query = + sql("select id, name.last, name.middle, name.first, relatives[''].last, " + + "relatives[''].middle, relatives[''].first, friends[0].last, friends[0].middle, " + + "friends[0].first, pets, address from contacts where p=2") + // We've selected every field in the schema. Therefore, no schema pruning should be performed. + // We check this by asserting that the scanned schema of the query is identical to the schema + // of the contacts relation, even though the fields are selected in different orders. + checkScan(query, + "struct,address:string,pets:int," + + "friends:array>," + + "relatives:map>>") + checkAnswer(query.orderBy("id"), + Row(2, "Jones", null, "Janet", null, null, null, null, null, null, null, "567 Maple Drive") :: + Row(3, "Jones", null, "Jim", null, null, null, null, null, null, null, "6242 Ash Street") :: + Nil) + } + + testSchemaPruning("empty schema intersection") { + val query = sql("select name.middle from contacts where p=2") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row(null) :: Row(null) :: Nil) + } + + private def testSchemaPruning(testName: String)(testThunk: => Unit) { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + test(s"Spark vectorized reader - without partition data column - $testName") { + withContacts(testThunk) + } + test(s"Spark vectorized reader - with partition data column - $testName") { + withContactsWithDataPartitionColumn(testThunk) + } + } + + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + test(s"Parquet-mr reader - without partition data column - $testName") { + withContacts(testThunk) + } + test(s"Parquet-mr reader - with partition data column - $testName") { + withContactsWithDataPartitionColumn(testThunk) + } + } + } + + private def withContacts(testThunk: => Unit) { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeParquetFile(contacts, new File(path + "/contacts/p=1")) + makeParquetFile(briefContacts, new File(path + "/contacts/p=2")) + + spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts") + + testThunk + } + } + + private def withContactsWithDataPartitionColumn(testThunk: => Unit) { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeParquetFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1")) + makeParquetFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2")) + + spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts") + + testThunk + } + } + + case class MixedCaseColumn(a: String, B: Int) + case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn) + + private val mixedCaseData = + MixedCase(0, "r0c1", MixedCaseColumn("abc", 1)) :: + MixedCase(1, "r1c1", MixedCaseColumn("123", 2)) :: + Nil + + 
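The mixed-case tests below reuse the same checkScan assertion as the pruning tests above, so it may help to see what that assertion corresponds to outside the suite. Here is a minimal sketch, assuming nested schema pruning is enabled (the suite presumably gets this from SchemaPruningTest; the SQLConf constant used below is an assumed name) and using illustrative case classes, object name, and temp path rather than anything from the patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.internal.SQLConf

object NestedPruningSketch {
  case class FullName(first: String, middle: String, last: String)
  case class Contact(id: Int, name: FullName, address: String)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("nested-pruning-sketch").getOrCreate()
    import spark.implicits._

    // Assumed config entry for nested schema pruning; the exact constant may differ by version.
    spark.conf.set(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, "true")

    val path = java.nio.file.Files.createTempDirectory("contacts-sketch").toString
    Seq(Contact(0, FullName("Jane", "X.", "Doe"), "123 Main Street"))
      .toDF().write.mode("overwrite").parquet(path)

    // A query that only touches name.middle should leave a scan that requires just
    // struct<name:struct<middle:string>> rather than the full Contact schema.
    val query = spark.read.parquet(path).select("name.middle")
    query.queryExecution.executedPlan.collect {
      case scan: FileSourceScanExec => println(scan.requiredSchema.catalogString)
    }

    spark.stop()
  }
}

The checkScan calls in this suite assert exactly that property: the requiredSchema of each FileSourceScanExec must match the expected pruned catalog string.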
testMixedCasePruning("select with exact column names") { + val query = sql("select CoL1, coL2.B from mixedcase") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("r0c1", 1) :: + Row("r1c1", 2) :: + Nil) + } + + testMixedCasePruning("select with lowercase column names") { + val query = sql("select col1, col2.b from mixedcase") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("r0c1", 1) :: + Row("r1c1", 2) :: + Nil) + } + + testMixedCasePruning("select with different-case column names") { + val query = sql("select cOL1, cOl2.b from mixedcase") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("r0c1", 1) :: + Row("r1c1", 2) :: + Nil) + } + + testMixedCasePruning("filter with different-case column names") { + val query = sql("select id from mixedcase where Col2.b = 2") + // Pruning with filters is currently unsupported. As-is, the file reader will read the id column + // and the entire coL2 struct. Once pruning with filters has been implemented we can uncomment + // this line + // checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), Row(1) :: Nil) + } + + private def testMixedCasePruning(testName: String)(testThunk: => Unit) { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.CASE_SENSITIVE.key -> "true") { + test(s"Spark vectorized reader - case-sensitive parser - mixed-case schema - $testName") { + withMixedCaseData(testThunk) + } + } + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", + SQLConf.CASE_SENSITIVE.key -> "false") { + test(s"Parquet-mr reader - case-insensitive parser - mixed-case schema - $testName") { + withMixedCaseData(testThunk) + } + } + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.CASE_SENSITIVE.key -> "false") { + test(s"Spark vectorized reader - case-insensitive parser - mixed-case schema - $testName") { + withMixedCaseData(testThunk) + } + } + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", + SQLConf.CASE_SENSITIVE.key -> "true") { + test(s"Parquet-mr reader - case-sensitive parser - mixed-case schema - $testName") { + withMixedCaseData(testThunk) + } + } + } + + private def withMixedCaseData(testThunk: => Unit) { + withParquetTable(mixedCaseData, "mixedcase") { + testThunk + } + } + + private val schemaEquality = new Equality[StructType] { + override def areEqual(a: StructType, b: Any): Boolean = + b match { + case otherType: StructType => a.sameType(otherType) + case _ => false + } + } + + protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + checkScanSchemata(df, expectedSchemaCatalogStrings: _*) + // We check here that we can execute the query without throwing an exception. 
The results + // themselves are irrelevant, and should be checked elsewhere as needed + df.collect() + } + + private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + val fileSourceScanSchemata = + df.queryExecution.executedPlan.collect { + case scan: FileSourceScanExec => scan.requiredSchema + } + assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, + s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + + s"but expected $expectedSchemaCatalogStrings") + fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach { + case (scanSchema, expectedScanSchemaCatalogString) => + val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString) + implicit val equality = schemaEquality + assert(scanSchema === expectedScanSchema) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 9d3dfae348beb..7eefedb8ff5bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -430,9 +430,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) assert(col.length == 1) if (col(0).dataType == StringType) { - assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + assert(errMsg.contains("Column: [a], Expected: int, Found: BINARY")) } else { - assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + assert(errMsg.endsWith("Column: [a], Expected: string, Found: INT32")) } } } @@ -1014,19 +1014,21 @@ class ParquetSchemaSuite extends ParquetSchemaTest { testName: String, parquetSchema: String, catalystSchema: StructType, - expectedSchema: String): Unit = { + expectedSchema: String, + caseSensitive: Boolean = true): Unit = { testSchemaClipping(testName, parquetSchema, catalystSchema, - MessageTypeParser.parseMessageType(expectedSchema)) + MessageTypeParser.parseMessageType(expectedSchema), caseSensitive) } private def testSchemaClipping( testName: String, parquetSchema: String, catalystSchema: StructType, - expectedSchema: MessageType): Unit = { + expectedSchema: MessageType, + caseSensitive: Boolean): Unit = { test(s"Clipping - $testName") { val actual = ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema) + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive) try { expectedSchema.checkContains(actual) @@ -1387,7 +1389,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { catalystSchema = new StructType(), - expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE) + expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE, + caseSensitive = true) testSchemaClipping( "disjoint field sets", @@ -1544,4 +1547,52 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin) + + testSchemaClipping( + "case-insensitive resolution: no ambiguity", + parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = 
true) + }, + expectedSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + caseSensitive = false) + + test("Clipping - case-insensitive resolution: more than one field is matched") { + val parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + | optional int32 a; + |} + """.stripMargin + val catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + } + assertThrows[RuntimeException] { + ParquetReadSupport.clipParquetSchema( + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index fff0f82f9bc2b..a302d67b5cbf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types.{StringType, StructType} -class WholeTextFileSuite extends QueryTest with SharedSQLContext { +class WholeTextFileSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which // can cause Filesystem.get(Configuration) to return a cached instance created with a different @@ -35,13 +35,10 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { protected override def sparkConf = super.sparkConf.set("spark.hadoop.fs.file.impl.disable.cache", "true") - private def testFile: String = { - Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString - } - test("reading text file with option wholetext=true") { val df = spark.read.option("wholetext", "true") - .format("text").load(testFile) + .format("text") + .load(testFile("test-data/text-suite.txt")) // schema assert(df.schema == new StructType().add("value", StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index adcaf2d76519f..8251ff159e05f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.TestData @@ -33,14 +34,16 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) 
assert(res.contains("Object[]")) } test("debugCodegenStringSeq") { - val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.length == 2) assert(res.forall{ case (subtree, code) => subtree.contains("Range") && code.contains("Object[]")}) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 51f8c3325fdff..d9b34dcd16476 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.CompactBuffer @@ -254,6 +254,59 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { map.free() } + test("SPARK-24257: insert big values into LongToUnsafeRowMap") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Array[DataType](StringType)) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + + val key = 0L + // the page array is initialized with length 1 << 17 (1M bytes), + // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug + val bigStr = UTF8String.fromString("x" * (1 << 19)) + + map.append(key, unsafeProj(InternalRow(bigStr))) + map.optimize() + + val resultRow = new UnsafeRow(1) + assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr) + map.free() + } + + test("SPARK-24809: Serializing LongToUnsafeRowMap in executor may result in data error") { + val unsafeProj = UnsafeProjection.create(Array[DataType](LongType)) + val originalMap = new LongToUnsafeRowMap(mm, 1) + + val key1 = 1L + val value1 = 4852306286022334418L + + val key2 = 2L + val value2 = 8813607448788216010L + + originalMap.append(key1, unsafeProj(InternalRow(value1))) + originalMap.append(key2, unsafeProj(InternalRow(value2))) + originalMap.optimize() + + val ser = sparkContext.env.serializer.newInstance() + // Simulate serialize/deserialize twice on driver and executor + val firstTimeSerialized = ser.deserialize[LongToUnsafeRowMap](ser.serialize(originalMap)) + val secondTimeSerialized = + ser.deserialize[LongToUnsafeRowMap](ser.serialize(firstTimeSerialized)) + + val resultRow = new UnsafeRow(1) + assert(secondTimeSerialized.getValue(key1, resultRow).getLong(0) === value1) + assert(secondTimeSerialized.getValue(key2, resultRow).getLong(0) === value2) + + originalMap.free() + firstTimeSerialized.free() + secondTimeSerialized.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index d456c931f5275..2cc55ff88b983 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -115,3 +115,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( dataType = BooleanType, pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) + +class MyDummyScalarPandasUDF extends UserDefinedPythonFunction( + name = "dummyScalarPandasUDF", + func = new DummyUDF, + dataType = BooleanType, + pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, + udfDeterministic = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala new file mode 100644 index 0000000000000..76b609d111acd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext + +class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.newProductEncoder + import testImplicits.localSeqToDatasetHolder + + val batchedPythonUDF = new MyDummyPythonUDF + val scalarPandasUDF = new MyDummyScalarPandasUDF + + private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect { + case b: BatchEvalPythonExec => b + } + + private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect { + case b: ArrowEvalPythonExec => b + } + + test("Chained Batched Python UDFs should be combined to a single physical node") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c", batchedPythonUDF(col("a"))) + .withColumn("d", batchedPythonUDF(col("c"))) + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 1) + } + + test("Chained Scalar Pandas UDFs should be combined to a single physical node") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c", scalarPandasUDF(col("a"))) + .withColumn("d", scalarPandasUDF(col("c"))) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(arrowEvalNodes.size == 1) + } + + test("Mixed Batched Python UDFs and Pandas UDF should be separate physical node") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c", batchedPythonUDF(col("a"))) + .withColumn("d", scalarPandasUDF(col("b"))) + + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 1) + assert(arrowEvalNodes.size == 1) + } + + test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c1", batchedPythonUDF(col("a"))) + .withColumn("c2", batchedPythonUDF(col("c1"))) + .withColumn("d1", scalarPandasUDF(col("a"))) + .withColumn("d2", scalarPandasUDF(col("d1"))) + + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 1) + assert(arrowEvalNodes.size == 1) + } + + test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c1", batchedPythonUDF(col("a"))) + .withColumn("d1", scalarPandasUDF(col("c1"))) + .withColumn("c2", batchedPythonUDF(col("d1"))) + .withColumn("d2", scalarPandasUDF(col("c2"))) + + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 2) + assert(arrowEvalNodes.size == 2) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala new file mode 100644 index 0000000000000..07e6034770127 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils + +class PythonForeachWriterSuite extends SparkFunSuite with Eventually { + + testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => + b.assertIteratorBlocked() + + b.add(Seq(1)) + b.assertOutput(Seq(1)) + b.assertIteratorBlocked() + + b.add(2 to 100) + b.assertOutput(1 to 100) + b.assertIteratorBlocked() + } + + testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b => + b.assertIteratorBlocked() + b.add(Seq(1)) + b.assertIteratorBlocked() + + b.allAdded() + b.assertThreadTerminated() + b.assertOutput(Seq(1)) + } + + testWithBuffer( + "UnsafeRowBuffer: handles more data than memory", + memBytes = 5, + sleepPerRowReadMs = 1) { b => + + b.assertIteratorBlocked() + b.add(1 to 2000) + b.assertOutput(1 to 2000) + } + + def testWithBuffer( + name: String, + memBytes: Long = 4 << 10, + sleepPerRowReadMs: Int = 0 + )(f: BufferTester => Unit): Unit = { + + test(name) { + var tester: BufferTester = null + try { + tester = new BufferTester(memBytes, sleepPerRowReadMs) + f(tester) + } finally { + if (tester != null) tester.close() + } + } + } + + + class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { + private val buffer = { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(memBytes) + val taskM = new TaskMemoryManager(mem, 0) + new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) + } + private val iterator = buffer.iterator + private val outputBuffer = new ArrayBuffer[Int] + private val testTimeout = timeout(20.seconds) + private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val thread = new Thread() { + override def run(): Unit = { + while (iterator.hasNext) { + outputBuffer.synchronized { + outputBuffer += iterator.next().getInt(0) + } + Thread.sleep(sleepPerRowReadMs) + } + } + } + thread.start() + + def add(ints: Seq[Int]): Unit = { + ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) } + } + + def allAdded(): Unit = { buffer.allRowsAdded() } + + def assertOutput(expectedOutput: Seq[Int]): Unit = { + eventually(testTimeout) { + val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq + assert(output == expectedOutput) + } + } + + def assertIteratorBlocked(): Unit = { + import Thread.State._ + eventually(testTimeout) { +
assert(thread.isAlive) + assert(thread.getState == TIMED_WAITING || thread.getState == WAITING) + } + } + + def assertThreadTerminated(): Unit = { + eventually(testTimeout) { assert(!thread.isAlive) } + } + + def close(): Unit = { + thread.interrupt() + thread.join() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e8420eee7fe9d..3bc36ce55d902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 36) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 72) } ignore("stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b5654..50f13bee251ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.execution.streaming import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { val partition = 1234 - val writer = new MemoryDataWriter(partition, OutputMode.Append()) - writer.write(Row(1)) - writer.write(Row(2)) - writer.write(Row(44)) + val writer = new MemoryDataWriter( + partition, OutputMode.Append(), new StructType().add("i", "int")) + writer.write(InternalRow(1)) + writer.write(InternalRow(2)) + writer.write(InternalRow(44)) val msg = writer.commit() assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) assert(msg.partition == partition) @@ -38,10 +41,11 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(writer.commit().data.isEmpty) } - test("continuous writer") { + test("streaming writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) - writer.commit(0, + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), new StructType().add("i", "int")) + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -49,7 +53,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writer.commit(19, + writeSupport.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -60,24 +64,24 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } - test("microbatch writer") { + test("writer metrics") { val sink = new 
MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + val schema = new StructType().add("i", "int") + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), schema) + // batch 0 + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))) )) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":6}") + // batch 1 + writeSupport.commit(1, Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) + MemoryWriterCommitMessage(0, Seq(Row(7), Row(8))) )) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) + assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":8}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala new file mode 100644 index 0000000000000..c228740df07c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.streaming.StreamTest + +class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("SPARK-24156: do not plan a no-data batch again after it has already been planned") { + val inputData = MemoryStream[Int] + val df = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(df)( + AddData(inputData, 10, 11, 12, 13, 14, 15), // Set watermark to 5 + CheckAnswer(), + AddData(inputData, 25), // Set watermark to 15 to make MicroBatchExecution run no-data batch + CheckAnswer((10, 5)), // Last batch should be a no-data batch + StopStream, + Execute { q => + // Delete the last committed batch from the commit log to signify that the last batch + // (a no-data batch) never completed + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purgeAfter(commit - 1) + }, + // Add data before start so that MicroBatchExecution can plan a batch. It should not, + // it should first re-run the incomplete no-data batch and then run a new batch to process + // new data. + AddData(inputData, 30), + StartStream(), + CheckNewAnswer((15, 1)), // This should not throw the error reported in SPARK-24156 + StopStream, + Execute { q => + // Delete the entire commit log + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purge(commit + 1) + }, + AddData(inputData, 50), + StartStream(), + CheckNewAnswer((25, 1), (30, 1)) // This should not throw the error reported in SPARK-24156 + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala index 55acf2ba28d2f..5884380271f0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala @@ -19,12 +19,10 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream -import org.scalatest.time.SpanSugar._ - import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.{StreamTest, Trigger} -class ConsoleWriterSuite extends StreamTest { +class ConsoleWriteSupportSuite extends StreamTest { import testImplicits._ test("microbatch - default") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala new file mode 100644 index 0000000000000..71dff443e8836 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable +import scala.language.implicitConversions + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ + +case class KV(key: Int, value: Long) + +class ForeachBatchSinkSuite extends StreamTest { + import testImplicits._ + + test("foreachBatch with non-stateful query") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), // out = in + 2 (i.e. 1 in query, 1 in writer) + check(in = 5, 6, 7)(out = 7, 8, 9)) + } + + test("foreachBatch with stateful query in update mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Update)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (1, 1L)), + check(in = 2, 3)(out = (0, 2L), (1, 2L))) + } + + test("foreachBatch with stateful query in complete mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Complete)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (0, 1L), (1, 1L)), + check(in = 2)(out = (0, 2L), (1, 1L))) + } + + test("foreachBatchSink does not affect metric generation") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), + checkMetrics) + } + + test("throws errors in invalid situations") { + val ds = MemoryStream[Int].toDS + val ex1 = intercept[IllegalArgumentException] { + ds.writeStream.foreachBatch(null.asInstanceOf[(Dataset[Int], Long) => Unit]).start() + } + assert(ex1.getMessage.contains("foreachBatch function cannot be null")) + val ex2 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start() + } + 
assert(ex2.getMessage.contains("'foreachBatch' is not supported with continuous trigger")) + val ex3 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start() + } + assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning")) + } + + // ============== Helper classes and methods ================= + + private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) { + trait Test + private case class Check(in: Seq[Int], out: Seq[T]) extends Test + private case object CheckMetrics extends Test + + private val recordedOutput = new mutable.HashMap[Long, Seq[T]] + + def testWriter( + ds: Dataset[T], + outputBatchWriter: (Dataset[T], Long) => Unit, + outputMode: OutputMode = OutputMode.Append())(tests: Test*): Unit = { + try { + var expectedBatchId = -1 + val query = ds.writeStream.outputMode(outputMode).foreachBatch(outputBatchWriter).start() + + tests.foreach { + case Check(in, out) => + expectedBatchId += 1 + memoryStream.addData(in) + query.processAllAvailable() + assert(recordedOutput.contains(expectedBatchId)) + val ds: Dataset[T] = spark.createDataset[T](recordedOutput(expectedBatchId)) + checkDataset[T](ds, out: _*) + case CheckMetrics => + assert(query.recentProgress.exists(_.numInputRows > 0)) + } + } finally { + sqlContext.streams.active.foreach(_.stop()) + } + } + + def check(in: Int*)(out: T*): Test = Check(in, out) + def checkMetrics: Test = CheckMetrics + def record(batchId: Long, ds: Dataset[T]): Unit = recordedOutput.put(batchId, ds.collect()) + implicit def conv(x: (Int, Long)): KV = KV(x._1, x._2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 03bf71b3f4b78..e60c339bc9cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -211,14 +211,12 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd try { inputData.addData(10, 11, 12) query.processAllAvailable() - inputData.addData(25) // Advance watermark to 15 seconds - query.processAllAvailable() inputData.addData(25) // Evict items less than previous watermark query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. 
val allEvents = ForeachWriterSuite.allEvents() - assert(allEvents.size === 3) + assert(allEvents.size === 4) val expectedEvents = Seq( Seq( ForeachWriterSuite.Open(partition = 0, version = 0), @@ -230,6 +228,10 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd ), Seq( ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 3), ForeachWriterSuite.Process(value = 3), ForeachWriterSuite.Close(None) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index ff14ec38e66a8..dd74af873c2e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.execution.streaming.sources -import java.nio.file.Files -import java.util.Optional import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -43,7 +41,7 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -54,19 +52,22 @@ class RateSourceSuite extends StreamTest { } test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find read support for rate") + withTempDir { temp => + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupportProvider => + val readSupport = ds.createMicroBatchReadSupport( + temp.getCanonicalPath, DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } } } test("compatible with old path in registry") { 
DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[RateStreamProvider]) case _ => throw new IllegalStateException("Could not find read support for rate") @@ -81,12 +82,43 @@ class RateSourceSuite extends StreamTest { .load() testStream(input)( AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - restart") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .load() + .select('value) + + var streamDuration = 0 + + // Microbatch rate stream offsets contain the number of seconds since the beginning of + // the stream. + def updateStreamDurationFromOffset(s: StreamExecution, expectedMin: Int): Unit = { + streamDuration = s.lastProgress.sources(0).endOffset.toInt + assert(streamDuration >= expectedMin) + } + + // We have to use the lambda version of CheckAnswer because we don't know the right range + // until we see the last offset. + def expectedResultsFromDuration(rows: Seq[Row]): Unit = { + assert(rows.map(_.getLong(0)).sorted == (0 until (streamDuration * 10))) + } + + testStream(input)( + StartStream(), + Execute(_.awaitOffset(0, LongOffset(2), streamingTimeout.toMillis)), StopStream, + Execute(updateStreamDurationFromOffset(_, 2)), + CheckAnswer(expectedResultsFromDuration _), StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + Execute(_.awaitOffset(0, LongOffset(4), streamingTimeout.toMillis)), + StopStream, + Execute(updateStreamDurationFromOffset(_, 4)), + CheckAnswer(expectedResultsFromDuration _) ) } @@ -107,70 +139,67 @@ class RateSourceSuite extends StreamTest { ) } - test("microbatch - set offset") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - test("microbatch - infer offsets") { - val tempFolder = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions( - Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), - tempFolder) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: LongOffset => assert(r.offset === 0L) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: LongOffset => assert(r.offset >= 100) - case _ => throw new IllegalStateException("unexpected offset type") + withTempDir { temp => + val readSupport = new RateStreamMicroBatchReadSupport( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + temp.getCanonicalPath) + readSupport.clock.asInstanceOf[ManualClock].advance(100000) + val startOffset = readSupport.initialOffset() + startOffset match { 
+ case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + readSupport.latestOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } } } test("microbatch - predetermined batch size") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() - val data = ArrayBuffer[Row]() - while (dataReader.next()) { - data.append(dataReader.get()) + withTempDir { temp => + val readSupport = new RateStreamMicroBatchReadSupport( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), + temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) + assert(tasks.size == 1) + val dataReader = readerFactory.createReader(tasks(0)) + val data = ArrayBuffer[InternalRow]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) } - assert(data.size === 20) } test("microbatch - data read") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } + withTempDir { temp => + val readSupport = new RateStreamMicroBatchReadSupport( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), + temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) + assert(tasks.size == 11) + + val readData = tasks + .map(readerFactory.createReader) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[InternalRow]() + while (reader.next()) buf.append(reader.get()) + buf + } - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + assert(readData.map(_.getLong(1)).sorted === 0.until(33).toArray) + } } test("valueAtSecond") { @@ -280,41 +309,44 @@ class RateSourceSuite extends StreamTest { } test("user-specified schema given") { - val exception = intercept[AnalysisException] { + val exception = intercept[UnsupportedOperationException] { spark.readStream .format("rate") .schema(spark.range(1).schema) .load() } assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) + "rate source does not support user-specified schema")) } test("continuous in 
registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) + case ds: ContinuousReadSupportProvider => + val readSupport = ds.createContinuousReadSupport( + "", DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport]) case _ => throw new IllegalStateException("Could not find read support for continuous rate") } } test("continuous data") { - val reader = new RateStreamContinuousReader( + val readSupport = new RateStreamContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() + val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createContinuousReaderFactory(config) assert(tasks.size == 2) - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() + val data = scala.collection.mutable.ListBuffer[InternalRow]() + tasks.foreach { + case t: RateStreamContinuousInputPartition => + val startTimeMs = readSupport.initialOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + val r = readerFactory.createReader(t) + .asInstanceOf[RateStreamContinuousPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index a15a980bb92fd..409156e5ebc70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.execution.streaming.sources -import java.io.IOException -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp -import java.util.Optional import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -33,11 +31,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, 
StructType, TimestampType} +import org.apache.spark.sql.types._ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { @@ -48,14 +48,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread.join() serverThread = null } - if (batchReader != null) { - batchReader.stop() - batchReader = null - } } private var serverThread: ServerThread = null - private var batchReader: MicroBatchReader = null case class AddSocketData(data: String*) extends AddData { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { @@ -64,7 +59,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source } if (sources.isEmpty) { throw new Exception( @@ -90,7 +85,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => throw new IllegalStateException("Could not find socket source") @@ -101,7 +96,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -130,7 +125,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val socket = spark .readStream .format("socket") @@ -180,16 +175,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map.empty[String, String].asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("port" -> "1234").asJava)) } } @@ -198,7 +193,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReader(Optional.empty(), "", a) + provider.createMicroBatchReadSupport("", a) } } 
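For orientation, a hedged usage sketch of the socket source whose option handling the tests above exercise; it is illustrative only and not part of the patch, and the host and port values are placeholders.

// Sketch (not part of the patch): the options validated above are the same ones a user
// passes through the public readStream API.
val socketLines = spark.readStream
  .format("socket")
  .option("host", "localhost")          // required, hence the "params not given" test
  .option("port", "9999")               // required; placeholder value
  .option("includeTimestamp", "true")   // optional, but must parse as a boolean
  .load()
// With includeTimestamp enabled the schema carries a timestamp column alongside value.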
@@ -208,28 +203,19 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("name", StringType) :: StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") - val exception = intercept[AnalysisException] { - provider.createMicroBatchReader( - Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + val exception = intercept[UnsupportedOperationException] { + provider.createMicroBatchReadSupport( + userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) } assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) - } - - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - batchReader = provider.createMicroBatchReader( - Optional.empty(), "", new DataSourceOptions(parameters.asJava)) - } + "socket source does not support user-specified schema")) } test("input row metrics") { serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -256,6 +242,162 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("verify ServerThread only accepts the first connection") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + // we are trying to connect to the server once again which should fail + try { + val socket2 = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + testStream(socket2)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + fail("StreamingQueryException is expected!") + } catch { + case e: StreamingQueryException if e.cause.isInstanceOf[SocketException] => // pass + } + } + } + + test("continuous data") { + serverThread = new ServerThread() + serverThread.start() + + val readSupport = new TextSocketContinuousReadSupport( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "port" -> serverThread.port.toString).asJava)) + + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) + assert(tasks.size == 2) + + val numRecords = 10 + val data = scala.collection.mutable.ListBuffer[Int]() + val offsets = scala.collection.mutable.ListBuffer[Int]() + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) + import org.scalatest.time.SpanSugar._ + failAfter(5 seconds) { + // inject rows, read and check the data and offsets + for (i <- 0 until numRecords) { + serverThread.enqueue(i.toString) + } + tasks.foreach { + case t: 
TextSocketContinuousInputPartition => + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] + for (i <- 0 until numRecords / 2) { + r.next() + offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset) + data.append(r.get().get(0, DataTypes.StringType).asInstanceOf[String].toInt) + // commit the offsets in the middle and validate if processing continues + if (i == 2) { + commitOffset(t.partitionId, i + 1) + } + } + assert(offsets.toSeq == Range.inclusive(1, 5)) + assert(data.toSeq == Range(t.partitionId, 10, 2)) + offsets.clear() + data.clear() + case _ => throw new IllegalStateException("Unexpected task type") + } + assert(readSupport.startOffset.offsets == List(3, 3)) + readSupport.commit(TextSocketOffset(List(5, 5))) + assert(readSupport.startOffset.offsets == List(5, 5)) + } + + def commitOffset(partition: Int, offset: Int): Unit = { + val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset) + readSupport.commit(TextSocketOffset(offsetsToCommit)) + assert(readSupport.startOffset.offsets == offsetsToCommit) + } + } + + test("continuous data - invalid commit") { + serverThread = new ServerThread() + serverThread.start() + + val readSupport = new TextSocketContinuousReadSupport( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "port" -> serverThread.port.toString).asJava)) + + readSupport.startOffset = TextSocketOffset(List(5, 5)) + assertThrows[IllegalStateException] { + readSupport.commit(TextSocketOffset(List(6, 6))) + } + } + + test("continuous data with timestamp") { + serverThread = new ServerThread() + serverThread.start() + + val readSupport = new TextSocketContinuousReadSupport( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "includeTimestamp" -> "true", + "port" -> serverThread.port.toString).asJava)) + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) + assert(tasks.size == 2) + + val numRecords = 4 + // inject rows, read and check the data and offsets + for (i <- 0 until numRecords) { + serverThread.enqueue(i.toString) + } + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) + tasks.foreach { + case t: TextSocketContinuousInputPartition => + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] + for (i <- 0 until numRecords / 2) { + r.next() + assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) + .isInstanceOf[(String, Timestamp)]) + } + case _ => throw new IllegalStateException("Unexpected task type") + } + } + + /** + * This class tries to mimic the behavior of netcat, so that we can ensure + * TextSocketStream supports netcat, which only accepts the first connection + * and exits the process when the first connection is closed. + * + * Please refer SPARK-24466 for more details. + */ private class ServerThread extends Thread with Logging { private val serverSocketChannel = ServerSocketChannel.open() serverSocketChannel.bind(new InetSocketAddress(0)) @@ -265,36 +407,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before override def run(): Unit = { try { - while (true) { - val clientSocketChannel = serverSocketChannel.accept() - clientSocketChannel.configureBlocking(false) - clientSocketChannel.socket().setTcpNoDelay(true) - - // Check whether remote client is closed but still send data to this closed socket. 
- // This happens in DataStreamReader where a source will be created to get the schema. - var remoteIsClosed = false - var cnt = 0 - while (cnt < 3 && !remoteIsClosed) { - if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { - cnt += 1 - Thread.sleep(100) - } else { - remoteIsClosed = true - } - } + val clientSocketChannel = serverSocketChannel.accept() - if (remoteIsClosed) { - logInfo(s"remote client ${clientSocketChannel.socket()} is closed") - } else { - while (true) { - val line = messageQueue.take() + "\n" - clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) - } - } + // Close the server socket channel immediately to mimic the behavior that + // only the first connection is accepted and any further connections are denied. + // Note that the already accepted client socket channel remains available + serverSocketChannel.close() + + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + + while (true) { + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) + } } catch { case e: InterruptedException => } finally { + // no harm to call close() again... serverSocketChannel.close() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala new file mode 100644 index 0000000000000..dec30fd01f7e2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types._ + + +class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { + + import testImplicits._ + import FlatMapGroupsWithStateExecHelper._ + + // ============================ StateManagerImplV1 ============================ + + test(s"StateManager v1 - primitive type - without timestamp") { + val schema = new StructType().add("value", IntegerType, nullable = false) + testStateManagerWithoutTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - primitive type - with timestamp") { + val schema = new StructType() + .add("value", IntegerType, nullable = false) + .add("timeoutTimestamp", IntegerType, nullable = false) + testStateManagerWithTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null state + intercept[Exception] { + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) + } + } + + test(s"StateManager v1 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null state + intercept[Exception] { + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) + } + } + + // ============================ StateManagerImplV2 ============================ + + test(s"StateManager v2 - primitive type - without timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + testStateManagerWithoutTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - primitive type - with timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + .add("timeoutTimestamp", LongType, nullable = false) + testStateManagerWithTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", 
StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithoutTimestamp[NestedStruct](version = 2, schema, testValues) + } + + test(s"StateManager v2 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))), + StructField("timeoutTimestamp", LongType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithTimestamp[NestedStruct](version = 2, schema, testValues) + } + + + def testStateManagerWithoutTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = false) + assert(stateManager.stateSchema === expectedStateSchema) + testStateManager(stateManager, testValues, NO_TIMESTAMP) + } + + def testStateManagerWithTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = true) + assert(stateManager.stateSchema === expectedStateSchema) + for (timestamp <- Seq(NO_TIMESTAMP, 1000)) { + testStateManager(stateManager, testValues, timestamp) + } + } + + private def testStateManager[T: Encoder]( + stateManager: StateManager, + values: Seq[T], + timestamp: Long): Unit = { + val keys = (1 to values.size).map(_ => newKey()) + val store = new MemoryStateStore() + + // Test stateManager.getState(), putState(), removeState() + keys.zip(values).foreach { case (key, value) => + try { + stateManager.putState(store, key, value, timestamp) + val data = stateManager.getState(store, key) + assert(data.stateObj == value) + assert(data.timeoutTimestamp === timestamp) + stateManager.removeState(store, key) + assert(stateManager.getState(store, key).stateObj == null) + } catch { + case e: Throwable => + fail(s"put/get/remove test with '$value' failed", e) + } + } + + // Test stateManager.getAllState() + for (i <- keys.indices) { + stateManager.putState(store, keys(i), values(i), timestamp) + } + val allData = stateManager.getAllState(store).map(_.copy()).toArray + assert(allData.map(_.timeoutTimestamp).toSet == Set(timestamp)) + assert(allData.map(_.stateObj).toSet == values.toSet) + } + + private def newStateManager[T: Encoder](version: Int, withTimestamp: Boolean): StateManager = { + FlatMapGroupsWithStateExecHelper.createStateManager( + implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]], + withTimestamp, + version) + } + + private val proj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val keyCounter = new AtomicInteger(0) + private def newKey(): UnsafeRow = { + proj.apply(new GenericInternalRow(Array[Any](keyCounter.getAndDecrement()))).copy() + } +} + +case class Struct(d: Double, str: String) +case class NestedStruct(i: Int, nested: Struct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala new file mode 100644 index 
0000000000000..98586d6492c9e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class MemoryStateStore extends StateStore() { + import scala.collection.JavaConverters._ + private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] + + override def iterator(): Iterator[UnsafeRowPair] = { + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } + } + + override def get(key: UnsafeRow): UnsafeRow = map.get(key) + + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key.copy(), newValue.copy()) + + override def remove(key: UnsafeRow): Unit = map.remove(key) + + override def commit(): Long = version + 1 + + override def abort(): Unit = {} + + override def id: StateStoreId = null + + override def version: Long = 0 + + override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) + + override def hasCommitted: Boolean = true +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 65b39f0fbd73d..579a364ebc3e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -55,7 +55,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) @@ -73,7 +73,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString def makeStoreRDD( spark: SparkSession, @@ -101,7 +101,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("usage with iterators - only gets and only puts") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { 
spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 // Returns an iterator of the incremented value made into the store @@ -149,7 +149,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn quietly { val queryRunId = UUID.randomUUID val opId = 0 - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext @@ -189,7 +189,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) .getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 73f8705060402..5e973145b0a37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util import java.util.UUID import scala.collection.JavaConverters._ @@ -47,6 +48,7 @@ import org.apache.spark.util.Utils class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] + type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow] import StateStoreCoordinatorSuite._ import StateStoreTestsHelper._ @@ -64,21 +66,143 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] require(!StateStore.isMaintenanceRunning) } + def updateVersionTo( + provider: StateStoreProvider, + currentVersion: Int, + targetVersion: Int): Int = { + var newCurrentVersion = currentVersion + for (i <- newCurrentVersion until targetVersion) { + newCurrentVersion = incrementVersion(provider, i) + } + require(newCurrentVersion === targetVersion) + newCurrentVersion + } + + def incrementVersion(provider: StateStoreProvider, currentVersion: Int): Int = { + val store = provider.getStore(currentVersion) + put(store, "a", currentVersion + 1) + store.commit() + currentVersion + 1 + } + + def checkLoadedVersions( + loadedMaps: util.SortedMap[Long, ProviderMapType], + count: Int, + earliestKey: Long, + latestKey: Long): Unit = { + assert(loadedMaps.size() === count) + assert(loadedMaps.firstKey() === earliestKey) + assert(loadedMaps.lastKey() === latestKey) + } + + def checkVersion( + loadedMaps: util.SortedMap[Long, ProviderMapType], + version: Long, + expectedData: Map[String, Int]): Unit = { + + val originValueMap = loadedMaps.get(version).asScala.map { entry => + rowToString(entry._1) -> 
rowToInt(entry._2) + }.toMap + + assert(originValueMap === expectedData) + } + + test("retaining only two latest versions when MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") { + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, + numOfVersToRetainInMemory = 2) + + var currentVersion = 0 + + // commit the ver 1 : cache will have one element + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 1)) + var loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1) + checkVersion(loadedMaps, 1, Map("a" -> 1)) + + // commit the ver 2 : cache will have two elements + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 2, earliestKey = 2, latestKey = 1) + checkVersion(loadedMaps, 2, Map("a" -> 2)) + checkVersion(loadedMaps, 1, Map("a" -> 1)) + + // commit the ver 3 : cache already has two elements, so adding ver 3 exceeds the cache size, + // and ver 3 will be added but ver 1 will be evicted + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 3)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 2, earliestKey = 3, latestKey = 2) + checkVersion(loadedMaps, 3, Map("a" -> 3)) + checkVersion(loadedMaps, 2, Map("a" -> 2)) + } + + test("failure after committing with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 1") { + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, + numOfVersToRetainInMemory = 1) + + var currentVersion = 0 + + // commit the ver 1 : cache will have one element + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 1)) + var loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1) + checkVersion(loadedMaps, 1, Map("a" -> 1)) + + // commit the ver 2 : cache already has one element, so adding ver 2 exceeds the cache size, + // and ver 2 will be added but ver 1 will be evicted. + // This ensures a cache miss will occur when this partition commits successfully + // but a failure afterwards forces the previous batch to be reprocessed + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2) + checkVersion(loadedMaps, 2, Map("a" -> 2)) + + // suppose a failure happened after committing, and the previous batch has to be reprocessed + currentVersion = 1 + + // commit to an existing version which was committed partially but abandoned globally + val store = provider.getStore(currentVersion) + // negative value to represent reprocessing + put(store, "a", -2) + store.commit() + currentVersion += 1 + + // make sure the newly committed version is reflected in the cache (overwritten) + assert(getData(provider) === Set("a" -> -2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2) + checkVersion(loadedMaps, 2, Map("a" -> -2)) + } + + test("no cache data with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 0") { + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, + numOfVersToRetainInMemory = 0) + + var currentVersion = 0 + + // commit the ver 1 : never cached + currentVersion = incrementVersion(provider,
currentVersion) + assert(getData(provider) === Set("a" -> 1)) + var loadedMaps = provider.getLoadedMaps() + assert(loadedMaps.size() === 0) + + // commit the ver 2 : never cached + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 2)) + loadedMaps = provider.getLoadedMaps() + assert(loadedMaps.size() === 0) + } + test("snapshotting") { val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) var currentVersion = 0 - def updateVersionTo(targetVersion: Int): Unit = { - for (i <- currentVersion + 1 to targetVersion) { - val store = provider.getStore(currentVersion) - put(store, "a", i) - store.commit() - currentVersion += 1 - } - require(currentVersion === targetVersion) - } - updateVersionTo(2) + currentVersion = updateVersionTo(provider, currentVersion, 2) require(getData(provider) === Set("a" -> 2)) provider.doMaintenance() // should not generate snapshot files assert(getData(provider) === Set("a" -> 2)) @@ -89,7 +213,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } // After version 6, snapshotting should generate one snapshot file - updateVersionTo(6) + currentVersion = updateVersionTo(provider, currentVersion, 6) require(getData(provider) === Set("a" -> 6), "store not updated correctly") provider.doMaintenance() // should generate snapshot files @@ -104,7 +228,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] "snapshotting messed up the data of the final version") // After version 20, snapshotting should generate newer snapshot files - updateVersionTo(20) + currentVersion = updateVersionTo(provider, currentVersion, 20) require(getData(provider) === Set("a" -> 20), "store not updated correctly") provider.doMaintenance() // do snapshot @@ -193,6 +317,22 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(store.metrics.memoryUsedBytes > noDataMemoryUsed) } + test("reports memory usage on current version") { + def getSizeOfStateForCurrentVersion(metrics: StateStoreMetrics): Long = { + val metricPair = metrics.customMetrics.find(_._1.name == "stateOnCurrentVersionSizeBytes") + assert(metricPair.isDefined) + metricPair.get._2 + } + + val provider = newStoreProvider() + val store = provider.getStore(0) + val noDataMemoryUsed = getSizeOfStateForCurrentVersion(store.metrics) + + put(store, "a", 1) + store.commit() + assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed) + } + test("StateStore.get") { quietly { val dir = newDir() @@ -507,6 +647,90 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) } + test("expose metrics with custom metrics to StateStoreMetrics") { + def getCustomMetric(metrics: StateStoreMetrics, name: String): Long = { + val metricPair = metrics.customMetrics.find(_._1.name == name) + assert(metricPair.isDefined) + metricPair.get._2 + } + + def getLoadedMapSizeMetric(metrics: StateStoreMetrics): Long = { + metrics.memoryUsedBytes + } + + def assertCacheHitAndMiss( + metrics: StateStoreMetrics, + expectedCacheHitCount: Long, + expectedCacheMissCount: Long): Unit = { + val cacheHitCount = getCustomMetric(metrics, "loadedMapCacheHitCount") + val cacheMissCount = getCustomMetric(metrics, "loadedMapCacheMissCount") + assert(cacheHitCount === expectedCacheHitCount) + assert(cacheMissCount === expectedCacheMissCount) + } + + val provider = newStoreProvider() + + // Verify state 
before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + + assert(store.metrics.numKeys === 0) + + val initialLoadedMapSize = getLoadedMapSizeMetric(store.metrics) + assert(initialLoadedMapSize >= 0) + assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) + + put(store, "a", 1) + assert(store.metrics.numKeys === 1) + + put(store, "b", 2) + put(store, "aa", 3) + assert(store.metrics.numKeys === 3) + remove(store, _.startsWith("a")) + assert(store.metrics.numKeys === 1) + assert(store.commit() === 1) + + assert(store.hasCommitted) + + val loadedMapSizeForVersion1 = getLoadedMapSizeMetric(store.metrics) + assert(loadedMapSizeForVersion1 > initialLoadedMapSize) + assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) + + val storeV2 = provider.getStore(1) + assert(!storeV2.hasCommitted) + assert(storeV2.metrics.numKeys === 1) + + put(storeV2, "cc", 4) + assert(storeV2.metrics.numKeys === 2) + assert(storeV2.commit() === 2) + + assert(storeV2.hasCommitted) + + val loadedMapSizeForVersion1And2 = getLoadedMapSizeMetric(storeV2.metrics) + assert(loadedMapSizeForVersion1And2 > loadedMapSizeForVersion1) + assertCacheHitAndMiss(storeV2.metrics, expectedCacheHitCount = 1, expectedCacheMissCount = 0) + + val reloadedProvider = newStoreProvider(store.id) + // intended to load version 2 instead of 1 + // version 2 will not be loaded to the cache in provider + val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.metrics.numKeys === 1) + + assert(getLoadedMapSizeMetric(reloadedStore.metrics) === loadedMapSizeForVersion1) + assertCacheHitAndMiss(reloadedStore.metrics, expectedCacheHitCount = 0, + expectedCacheMissCount = 1) + + // now we are loading version 2 + val reloadedStoreV2 = reloadedProvider.getStore(2) + assert(reloadedStoreV2.metrics.numKeys === 2) + + assert(getLoadedMapSizeMetric(reloadedStoreV2.metrics) > loadedMapSizeForVersion1) + assertCacheHitAndMiss(reloadedStoreV2.metrics, expectedCacheHitCount = 0, + expectedCacheMissCount = 2) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } @@ -535,9 +759,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] partition: Int, dir: String = newDir(), minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + numOfVersToRetainInMemory: Int = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get, hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = { val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, numOfVersToRetainInMemory) sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val provider = new HDFSBackedStateStoreProvider() provider.init( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala new file mode 100644 index 0000000000000..daacdfd58c7b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class StreamingAggregationStateManagerSuite extends StreamTest { + // ============================ fields and method for test data ============================ + + val testKeys: Seq[String] = Seq("key1", "key2") + val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)") + + val testOutputSchema: StructType = StructType( + testKeys.map(createIntegerField) ++ testValues.map(createIntegerField)) + + val testOutputAttributes: Seq[Attribute] = testOutputSchema.toAttributes + val testKeyAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testKeys.contains(p.name) + } + val testValuesAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testValues.contains(p.name) + } + val expectedTestValuesSchema: StructType = testValuesAttributes.toStructType + + val testRow: UnsafeRow = { + val unsafeRowProjection = UnsafeProjection.create(testOutputSchema) + val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema)) + (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) } + row + } + + val expectedTestKeyRow: UnsafeRow = { + val keyProjector = GenerateUnsafeProjection.generate(testKeyAttributes, testOutputAttributes) + keyProjector(testRow) + } + + val expectedTestValueRowForV2: UnsafeRow = { + val valueProjector = GenerateUnsafeProjection.generate(testValuesAttributes, + testOutputAttributes) + valueProjector(testRow) + } + + private def createIntegerField(name: String): StructField = { + StructField(name, IntegerType, nullable = false) + } + + // ============================ StateManagerImplV1 ============================ + + test("StateManager v1 - get, put, iter") { + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 1) + + // in V1, input row is stored as value + testGetPutIterOnStateManager(stateManager, testOutputSchema, testRow, + expectedTestKeyRow, expectedStateValue = testRow) + } + + // ============================ StateManagerImplV2 ============================ + test("StateManager v2 - get, put, iter") { + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 2) + + // in V2, row for values itself (excluding keys from input row) is stored as value + // so that stored value doesn't have key part, but state manager V2 will provide same output + // as V1 when getting row for key + 
testGetPutIterOnStateManager(stateManager, expectedTestValuesSchema, testRow, + expectedTestKeyRow, expectedTestValueRowForV2) + } + + private def testGetPutIterOnStateManager( + stateManager: StreamingAggregationStateManager, + expectedValueSchema: StructType, + inputRow: UnsafeRow, + expectedStateKey: UnsafeRow, + expectedStateValue: UnsafeRow): Unit = { + + assert(stateManager.getStateValueSchema === expectedValueSchema) + + val memoryStateStore = new MemoryStateStore() + stateManager.put(memoryStateStore, inputRow) + + assert(memoryStateStore.iterator().size === 1) + assert(stateManager.iterator(memoryStateStore).size === memoryStateStore.iterator().size) + + val keyRow = stateManager.getKey(inputRow) + assert(keyRow === expectedStateKey) + + // iterate state store and verify whether expected format of key and value are stored + val pair = memoryStateStore.iterator().next() + assert(pair.key === keyRow) + assert(pair.value === expectedStateValue) + + // iterate with state manager and see whether original rows are returned as values + val pairFromStateManager = stateManager.iterator(memoryStateStore).next() + assert(pairFromStateManager.key === keyRow) + assert(pairFromStateManager.value === inputRow) + + // following as keys and values + assert(stateManager.keys(memoryStateStore).next() === keyRow) + assert(stateManager.values(memoryStateStore).next() === inputRow) + + // verify the stored value once again via get + assert(memoryStateStore.get(keyRow) === expectedStateValue) + + // state manager should return row which is same as input row regardless of format version + assert(inputRow === stateManager.get(memoryStateStore, keyRow)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index f3f08839c1d3a..02df45d1b7989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -443,7 +443,8 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val oldCount = statusStore.executionsList().size val expectedAccumValue = 12345 - val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) + val expectedAccumValue2 = 54321 + val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan @@ -466,10 +467,14 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val execId = statusStore.executionsList().last.executionId val metrics = statusStore.executionMetrics(execId) val driverMetric = physicalPlan.metrics("dummy") + val driverMetric2 = physicalPlan.metrics("dummy2") val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2)) assert(metrics.contains(driverMetric.id)) assert(metrics(driverMetric.id) === expectedValue) + assert(metrics.contains(driverMetric2.id)) + assert(metrics(driverMetric2.id) === expectedValue2) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -562,20 +567,31 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with * A dummy 
[[org.apache.spark.sql.execution.SparkPlan]] that updates a [[SQLMetrics]] * on the driver. */ -private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExecNode { +private case class MyPlan(sc: SparkContext, expectedValue: Long, expectedValue2: Long) + extends LeafExecNode { + override def sparkContext: SparkContext = sc override def output: Seq[Attribute] = Seq() override val metrics: Map[String, SQLMetric] = Map( - "dummy" -> SQLMetrics.createMetric(sc, "dummy")) + "dummy" -> SQLMetrics.createMetric(sc, "dummy"), + "dummy2" -> SQLMetrics.createMetric(sc, "dummy2")) override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue + longMetric("dummy2") += expectedValue2 + + // postDriverMetricUpdates may happen multiple times in a query. + // (normally from different operators, but for the sake of testing, from one operator) + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + Seq(metrics("dummy"))) SQLMetrics.postDriverMetricUpdates( sc, sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), - metrics.values.toSeq) + Seq(metrics("dummy2"))) sc.emptyRDD } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index b55489cb2678a..4592a1663faed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -336,7 +336,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) val vector = ArrowUtils.toArrowField("struct", schema, nullable = false, null) - .createVector(allocator).asInstanceOf[NullableMapVector] + .createVector(allocator).asInstanceOf[StructVector] vector.allocateNew() val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] @@ -373,7 +373,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableMapVector] + .createVector(allocator).asInstanceOf[StructVector] vector.allocateNew() val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 772f687526008..f57f07b498261 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1333,4 +1333,11 @@ class ColumnarBatchSuite extends SparkFunSuite { column.close() } + + testVector("WritableColumnVector.reserve(): requested capacity is negative", 1024, ByteType) { + column => + val ex = intercept[RuntimeException] { column.reserve(-1) } + assert(ex.getMessage.contains( + "Cannot reserve additional contiguous bytes in the vectorized reader (integer overflow)")) + } } diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..5b4736ef4f7f3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.execution.debug.codegenStringSeq +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + pairs.foreach { case (k, v) => + SQLConf.get.setConfString(k, v) + } + try f finally { + pairs.foreach { case (k, _) => + SQLConf.get.unsetConf(k) + } + } + } + + test("ReadOnlySQLConf is correctly created at the executor side") { + withSQLConf("spark.sql.x" -> "a") { + val checks = spark.range(10).mapPartitions { _ => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } + + test("SPARK-24727 CODEGEN_CACHE_MAX_ENTRIES is correctly referenced at the executor side") { + withSQLConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key -> "300") { + val checks = spark.range(10).mapPartitions { _ => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && + conf.getConfString(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key) == "300") + }.collect() + assert(checks.forall(_ == true)) + } + } + + test("SPARK-22219: refactor to control to generate comment") { + Seq(true, false).foreach { flag => + withSQLConf(StaticSQLConf.CODEGEN_COMMENTS.key -> flag.toString) { + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 
2).count() + .queryExecution.executedPlan) + assert(res.length == 2) + assert(res.forall { case (_, code) => + (code.contains("* Codegend pipeline") == flag) && + (code.contains("// input[") == flag) + }) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala new file mode 100644 index 0000000000000..bb79d3a84e5a3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{LocalSparkSession, SparkSession} + +class SQLConfGetterSuite extends SparkFunSuite with LocalSparkSession { + + test("SPARK-25076: SQLConf should not be retrieved from a stopped SparkSession") { + spark = SparkSession.builder().master("local").getOrCreate() + assert(SQLConf.get eq spark.sessionState.conf, + "SQLConf.get should get the conf from the active spark session.") + spark.stop() + assert(SQLConf.get eq SQLConf.getFallbackConf, + "SQLConf.get should not get conf from a stopped spark session.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5238adce4a699..7fa0e7fc162ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -24,21 +24,22 @@ import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils 
-class JDBCSuite extends SparkFunSuite +class JDBCSuite extends QueryTest with BeforeAndAfter with PrivateMethodTester with SharedSQLContext { import testImplicits._ @@ -238,6 +239,22 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " + + "AS SELECT 1, 1") + .executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE test.datetime (d DATE, t TIMESTAMP)").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-06', '2018-07-06 05:50:00.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-06', '2018-07-06 08:10:08.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-08', '2018-07-08 13:32:01.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-12', '2018-07-12 09:51:15.0')").executeUpdate() + conn.commit() + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -255,21 +272,32 @@ class JDBCSuite extends SparkFunSuite s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:`$jdbcRelations`") } + private def checkPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is removed in a physical plan and + // the plan only has PhysicalRDD to scan JDBCRelation. + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.DataSourceScanExec]) + assert(node.child.asInstanceOf[DataSourceScanExec].nodeName.contains("JDBCRelation")) + df + } + + private def checkNotPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is not removed in a physical plan because JDBCRDD + // cannot compile given predicates. + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.FilterExec]) + df + } + test("SELECT *") { assert(sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - def checkPushdown(df: DataFrame): DataFrame = { - val parentPlan = df.queryExecution.executedPlan - // Check if SparkPlan Filter is removed in a physical plan and - // the plan only has PhysicalRDD to scan JDBCRelation. 
- assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) - val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] - assert(node.child.isInstanceOf[org.apache.spark.sql.execution.DataSourceScanExec]) - assert(node.child.asInstanceOf[DataSourceScanExec].nodeName.contains("JDBCRelation")) - df - } assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) @@ -302,15 +330,6 @@ class JDBCSuite extends SparkFunSuite "WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')") assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2))) - def checkNotPushdown(df: DataFrame): DataFrame = { - val parentPlan = df.queryExecution.executedPlan - // Check if SparkPlan Filter is not removed in a physical plan because JDBCRDD - // cannot compile given predicates. - assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) - val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] - assert(node.child.isInstanceOf[org.apache.spark.sql.execution.FilterExec]) - df - } assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2) } @@ -855,19 +874,51 @@ class JDBCSuite extends SparkFunSuite } test("truncate table query by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + val table = "weblogs" val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table" - assert(MySQL.getTruncateQuery(table) == defaultQuery) - assert(Postgres.getTruncateQuery(table) == postgresQuery) - assert(db2.getTruncateQuery(table) == defaultQuery) - assert(h2.getTruncateQuery(table) == defaultQuery) - assert(derby.getTruncateQuery(table) == defaultQuery) + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } + + assert(postgres.getTruncateQuery(table) == postgresQuery) + assert(oracle.getTruncateQuery(table) == defaultQuery) + assert(teradata.getTruncateQuery(table) == teradataQuery) + } + + test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") { + // cascade in a truncate should only be applied for databases that support this, + // even if the parameter is passed. 
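+ // For example, mirroring the assertions below: Postgres yields "TRUNCATE TABLE ONLY weblogs CASCADE", + // Oracle yields "TRUNCATE TABLE weblogs CASCADE", Teradata yields "DELETE FROM weblogs ALL", + // while MySQL, DB2, H2 and Derby keep the plain "TRUNCATE TABLE weblogs".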
+ val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + + val table = "weblogs" + val defaultQuery = s"TRUNCATE TABLE $table" + val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE" + val oracleQuery = s"TRUNCATE TABLE $table CASCADE" + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } + assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery) + assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery) + assert(teradata.getTruncateQuery(table, Some(true)) == teradataQuery) } test("Test DataFrame.where for Date and Timestamp") { @@ -1093,7 +1144,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive - assert(options.table == "t1") + assert(options.tableOrQuery == "t1") // When we convert it to properties, it should be case-sensitive. assert(options.asProperties.size == 3) assert(options.asProperties.get("customkey") == null) @@ -1190,4 +1241,241 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").schema === schema) } } + + test("SPARK-23856 Spark jdbc setQueryTimeout option") { + val numJoins = 100 + val longRunningQuery = + s"SELECT t0.NAME AS c0, ${(1 to numJoins).map(i => s"t$i.NAME AS c$i").mkString(", ")} " + + s"FROM test.people t0 ${(1 to numJoins).map(i => s"join test.people t$i").mkString(" ")}" + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("dbtable", s"($longRunningQuery)") + .option("queryTimeout", 1) + .load() + val errMsg = intercept[SparkException] { + df.collect() + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } + + test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") { + def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = { + val df = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PARTITION") + .option("partitionColumn", partColName) + .option("lowerBound", 1) + .option("upperBound", 4) + .option("numPartitions", 3) + .load() + + val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName) + df.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + s"$quotedPrtColName < 2 or $quotedPrtColName is null", + s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3", + s"$quotedPrtColName >= 3")) + } + } + + testJdbcParitionColumn("THEID", "THEID") + testJdbcParitionColumn("\"THEID\"", "THEID") + withSQLConf("spark.sql.caseSensitive" -> "false") { + testJdbcParitionColumn("ThEiD", "THEID") + } + testJdbcParitionColumn("THE ID", "THE ID") + + def testIncorrectJdbcPartitionColumn(partColName: String): Unit = { + val errMsg = intercept[AnalysisException] { + testJdbcParitionColumn(partColName, "THEID") + }.getMessage + assert(errMsg.contains(s"User-defined partition column 
$partColName not found " + + "in the JDBC relation:")) + } + + testIncorrectJdbcPartitionColumn("NoExistingColumn") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD")) + } + } + + test("query JDBC option - negative tests") { + val query = "SELECT * FROM test.people WHERE theid = 1" + // load path + val e1 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .option("dbtable", "test.people") + .load() + }.getMessage + assert(e1.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + // jdbc api path + val properties = new Properties() + properties.setProperty(JDBCOptions.JDBC_QUERY_STRING, query) + val e2 = intercept[RuntimeException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() + }.getMessage + assert(e2.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e3 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', dbtable 'TEST.PEOPLE', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e3.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e4 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", "") + .load() + }.getMessage + assert(e4.contains("Option `query` can not be empty.")) + + // Option query and partitioncolumn are not allowed together. + val expectedErrorMsg = + s""" + |Options 'query' and 'partitionColumn' can not be specified together. + |Please define the query using `dbtable` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + val e5 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e5.contains(expectedErrorMsg)) + } + + test("query JDBC option") { + val query = "SELECT name, theid FROM test.people WHERE theid = 1" + // query option to pass on the query string. + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .load() + checkAnswer( + df, + Row("fred", 1) :: Nil) + + // query option in the create table path. 
+ sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + checkAnswer( + sql("select name, theid from queryOption"), + Row("fred", 1) :: Nil) + } + + test("SPARK-22814 support date/timestamp types in partitionColumn") { + val expectedResult = Seq( + ("2018-07-06", "2018-07-06 05:50:00.0"), + ("2018-07-06", "2018-07-06 08:10:08.0"), + ("2018-07-08", "2018-07-08 13:32:01.0"), + ("2018-07-12", "2018-07-12 09:51:15.0") + ).map { case (date, timestamp) => + Row(Date.valueOf(date), Timestamp.valueOf(timestamp)) + } + + // DateType partition column + val df1 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.DATETIME") + .option("partitionColumn", "d") + .option("lowerBound", "2018-07-06") + .option("upperBound", "2018-07-20") + .option("numPartitions", 3) + .load() + + df1.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"D" < '2018-07-10' or "D" is null""", + """"D" >= '2018-07-10' AND "D" < '2018-07-14'""", + """"D" >= '2018-07-14'""")) + } + checkAnswer(df1, expectedResult) + + // TimestampType partition column + val df2 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.DATETIME") + .option("partitionColumn", "t") + .option("lowerBound", "2018-07-04 03:30:00.0") + .option("upperBound", "2018-07-27 14:11:05.0") + .option("numPartitions", 2) + .load() + + df2.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"T" < '2018-07-15 20:50:32.5' or "T" is null""", + """"T" >= '2018-07-15 20:50:32.5'""")) + } + checkAnswer(df2, expectedResult) + } + + test("throws an exception for unsupported partition column types") { + val errMsg = intercept[AnalysisException] { + spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PEOPLE") + .option("partitionColumn", "name") + .option("lowerBound", "aaa") + .option("upperBound", "zzz") + .option("numPartitions", 2) + .load() + }.getMessage + assert(errMsg.contains( + "Partition column type should be numeric, date, or timestamp, but string found.")) + } + + test("SPARK-24288: Enable preventing predicate pushdown") { + val table = "test.people" + + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("dbTable", table) + .option("pushDownPredicate", false) + .load() + .filter("theid = 1") + .select("name", "theid") + checkAnswer( + checkNotPushdown(df), + Row("fred", 1) :: Nil) + + // pushDownPredicate option in the create table path. 
+ sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW predicateOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$urlWithUserAndPass', dbTable '$table', pushDownPredicate 'false') + """.stripMargin.replaceAll("\n", " ")) + checkAnswer( + checkNotPushdown(sql("SELECT name, theid FROM predicateOption WHERE theid = 1")), + Row("fred", 1) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1985b1dc82879..b751ec2de4825 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -293,13 +293,23 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("save errors if dbtable is not specified") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val e = intercept[RuntimeException] { + val e1 = intercept[RuntimeException] { df.write.format("jdbc") .option("url", url1) .options(properties.asScala) .save() }.getMessage - assert(e.contains("Option 'dbtable' is required")) + assert(e1.contains("Option 'dbtable' or 'query' is required")) + + val e2 = intercept[RuntimeException] { + df.write.format("jdbc") + .option("url", url1) + .options(properties.asScala) + .option("query", "select * from TEST.SAVETEST") + .save() + }.getMessage + val msg = "Option 'dbtable' is required. Option 'query' is not applicable while writing." + assert(e2.contains(msg)) } test("save errors if wrong user/password combination") { @@ -515,4 +525,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { }.getMessage assert(e.contains("NULL not allowed for column \"NAME\"")) } + + ignore("SPARK-23856 Spark jdbc setQueryTimeout option") { + // The behaviour of the option `queryTimeout` depends on how JDBC drivers implement the API + // `setQueryTimeout`. For example, in the h2 JDBC driver, `executeBatch` invokes multiple + // INSERT queries in a batch and `setQueryTimeout` means that the driver checks the timeout + // of each query. In the PostgreSQL JDBC driver, `setQueryTimeout` means that the driver + // checks the timeout of an entire batch in a driver side. So, the test below fails because + // this test suite depends on the h2 JDBC driver and the JDBC write path internally + // uses `executeBatch`. 
+ val errMsg = intercept[SparkException] { + spark.range(10000000L).selectExpr("id AS k", "id AS v").coalesce(1).write + .mode(SaveMode.Overwrite) + .option("queryTimeout", 1) + .option("batchsize", Int.MaxValue) + .jdbc(url1, "TEST.TIMEOUTTEST", properties) + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index fb61fa716b946..a9414200e70f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -22,10 +22,11 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ @@ -52,6 +53,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + // number of buckets that doesn't yield empty buckets when bucketing on column j on df/nullDF + // empty buckets before filtering might hide bugs in pruning logic + private val NumBucketsForPruningDF = 7 + private val NumBucketsForPruningNullDf = 5 + test("read bucketed data") { withTable("bucketed_table") { df.write @@ -90,32 +96,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { - val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column assert(bucketColumnNames.length == 1) val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) - val matchedBuckets = new BitSet(numBuckets) - bucketValues.foreach { value => - matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) - } // Filter could hide the bug in bucket pruning. 
Thus, skipping all the filters val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) assert(rdd.isDefined, plan) - val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => - if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator() + // if nothing should be pruned, skip the pruning test + if (bucketValues.nonEmpty) { + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(BucketingUtils.getBucketIdFromValue(bucketColumn, numBuckets, value)) + } + val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of partitions that should have been pruned and are not empty + if (!matchedBuckets.get(index % numBuckets) && iter.nonEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (invalidBuckets.nonEmpty) { + fail(s"Buckets ${invalidBuckets.mkString(",")} should have been pruned from:\n$plan") + } } - // TODO: These tests are not testing the right columns. -// // checking if all the pruned buckets are empty -// val invalidBuckets = checkedResult.collect().toList -// if (invalidBuckets.nonEmpty) { -// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") -// } checkAnswer( bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), @@ -125,7 +136,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -155,13 +166,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = Seq(j, j + 1, j + 2, j + 3), filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), df) + + // Case 4: InSet + val inSetExpr = expressions.InSet($"j".expr, Set(j, j + 1, j + 2, j + 3).map(lit(_).expr)) + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1, j + 2, j + 3), + filterCondition = Column(inSetExpr), + df) } } } test("read non-partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -181,7 +200,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having null in bucketing key") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningNullDf val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here nullDF.write @@ -208,7 +227,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having composite filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -229,7 +248,62 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = j :: Nil, filterCondition = $"j" === j && $"i" > j % 5, df) + + // check multiple bucket 
values OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1), + filterCondition = $"j" === j || $"j" === (j + 1), + df) + + // check bucket value and none bucket value OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Nil, + filterCondition = $"j" === j || $"i" === 0, + df) + + // check AND condition in complex expression + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j), + filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j, + df) + } + } + } + + test("read bucketed table without filters") { + withTable("bucketed_table") { + val numBuckets = NumBucketsForPruningDF + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") + val plan = bucketedDataFrame.queryExecution.executedPlan + val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) + assert(rdd.isDefined, plan) + + val emptyBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of empty partitions + if (iter.isEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (emptyBuckets.nonEmpty) { + fail(s"Buckets ${emptyBuckets.mkString(",")} should not have been pruned from:\n$plan") } + + checkAnswer( + bucketedDataFrame.orderBy("i", "j", "k"), + df.orderBy("i", "j", "k")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 93f3efe2ccc4a..fc61050dc7458 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.sources import java.io.File -import java.net.URI import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.BucketingUtils @@ -48,19 +49,46 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) } - test("numBuckets be greater than 0 but less than 100000") { + test("numBuckets be greater than 0 but less/eq than default bucketing.maxBuckets (100000)") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - Seq(-1, 0, 100000).foreach(numBuckets => { + Seq(-1, 0, 100001).foreach(numBuckets => { val e = intercept[AnalysisException](df.write.bucketBy(numBuckets, "i").saveAsTable("tt")) assert( - e.getMessage.contains("Number of buckets should be greater than 0 but less than 100000")) + e.getMessage.contains("Number of buckets should be greater than 0 but less than")) }) } + test("numBuckets be greater than 0 but less/eq than overridden bucketing.maxBuckets (200000)") { + val maxNrBuckets: Int = 200000 + val catalog = spark.sessionState.catalog + + withSQLConf("spark.sql.sources.bucketing.maxBuckets" -> maxNrBuckets.toString) { + // within the new limit + Seq(100001, maxNrBuckets).foreach(numBuckets => { + withTable("t") { + df.write.bucketBy(numBuckets, "i").saveAsTable("t") + val table = 
catalog.getTableMetadata(TableIdentifier("t")) + assert(table.bucketSpec == Option(BucketSpec(numBuckets, Seq("i"), Seq()))) + } + }) + + // over the new limit + withTable("t") { + val e = intercept[AnalysisException]( + df.write.bucketBy(maxNrBuckets + 1, "i").saveAsTable("t")) + assert( + e.getMessage.contains("Number of buckets should be greater than 0 but less than")) + } + } + } + test("specify sorting columns without bucketing columns") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + val e = intercept[AnalysisException] { + df.write.sortBy("j").saveAsTable("tt") + } + assert(e.getMessage == "sortBy must be used together with bucketBy;") } test("sorting by non-orderable column") { @@ -74,7 +102,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").parquet("/tmp/path") } - assert(e.getMessage == "'save' does not support bucketing right now;") + assert(e.getMessage == "'save' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using save()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").parquet("/tmp/path") + } + assert(e.getMessage == "'save' does not support bucketBy and sortBy right now;") } test("write bucketed data using insertInto()") { @@ -83,7 +120,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").insertInto("tt") } - assert(e.getMessage == "'insertInto' does not support bucketing right now;") + assert(e.getMessage == "'insertInto' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using insertInto()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").insertInto("tt") + } + assert(e.getMessage == "'insertInto' does not support bucketBy and sortBy right now;") } private lazy val df = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 916a01ee0ca8e..d46029e84433c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -225,7 +225,7 @@ class CreateTableAsSelectSuite test("create table using as select - with invalid number of buckets") { withTable("t") { - Seq(0, 100000).foreach(numBuckets => { + Seq(0, 100001).foreach(numBuckets => { val e = intercept[AnalysisException] { sql( s""" @@ -236,11 +236,42 @@ class CreateTableAsSelectSuite """.stripMargin ) }.getMessage - assert(e.contains("Number of buckets should be greater than 0 but less than 100000")) + assert(e.contains("Number of buckets should be greater than 0 but less than")) }) } } + test("create table using as select - with overriden max number of buckets") { + def createTableSql(numBuckets: Int): String = + s""" + |CREATE TABLE t USING PARQUET + |OPTIONS (PATH '${path.toURI}') + |CLUSTERED BY (a) SORTED BY (b) INTO $numBuckets BUCKETS + |AS SELECT 1 AS a, 2 AS b + """.stripMargin + + val maxNrBuckets: Int = 200000 + val catalog = spark.sessionState.catalog + withSQLConf("spark.sql.sources.bucketing.maxBuckets" -> maxNrBuckets.toString) { + + // Within the new 
limit + Seq(100001, maxNrBuckets).foreach(numBuckets => { + withTable("t") { + sql(createTableSql(numBuckets)) + val table = catalog.getTableMetadata(TableIdentifier("t")) + assert(table.bucketSpec == Option(BucketSpec(numBuckets, Seq("a"), Seq("b")))) + } + }) + + // Over the new limit + withTable("t") { + val e = intercept[AnalysisException](sql(createTableSql(maxNrBuckets + 1))) + assert( + e.getMessage.contains("Number of buckets should be greater than 0 but less than ")) + } + } + } + test("SPARK-17409: CTAS of decimal calculation") { withTable("tab2") { withTempView("tab1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index fef01c860db6e..0b6d93975daef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -20,12 +20,36 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class SimpleInsertSource extends SchemaRelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + SimpleInsert(schema)(sqlContext.sparkSession) + } +} + +case class SimpleInsert(userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession) + extends BaseRelation with InsertableRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = userSpecifiedSchema + + override def insert(input: DataFrame, overwrite: Boolean): Unit = { + input.collect + } +} + class InsertSuite extends DataSourceTest with SharedSQLContext { import testImplicits._ @@ -520,4 +544,49 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } } + + test("SPARK-24860: dynamic partition overwrite specified per source without catalog table") { + withTempPath { path => + Seq((1, 1), (2, 2)).toDF("i", "part") + .write.partitionBy("part") + .parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1) :: Row(2, 2) :: Nil) + + Seq((1, 2), (1, 3)).toDF("i", "part") + .write.partitionBy("part").mode("overwrite") + .option("partitionOverwriteMode", "dynamic").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), + Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) + + Seq((1, 2), (1, 3)).toDF("i", "part") + .write.partitionBy("part").mode("overwrite") + .option("partitionOverwriteMode", "static").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 2) :: Row(1, 3) :: Nil) + } + } + + test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") { + withTable("test_table") { + val schema = new StructType() + .add("i", LongType, false) + .add("s", StringType, false) + val newTable = CatalogTable( + identifier = TableIdentifier("test_table", None), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + locationUri = None, + 
inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty), + schema = schema, + provider = Some(classOf[SimpleInsertSource].getName)) + + spark.sessionState.catalog.createTable(newTable, false) + + sql("INSERT INTO TABLE test_table SELECT 1, 'a'") + sql("INSERT INTO TABLE test_table SELECT 2, null") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 4adbff5c663bc..0aa67bf1b0d48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -76,20 +76,28 @@ class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext { classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat]) } - test("error message for unknown data sources") { - val error1 = intercept[AnalysisException] { - getProvidingClass("avro") + test("avro: show deploy guide for loading the external avro module") { + Seq("avro", "org.apache.spark.sql.avro").foreach { provider => + val message = intercept[AnalysisException] { + getProvidingClass(provider) + }.getMessage + assert(message.contains(s"Failed to find data source: $provider")) + assert(message.contains("Please deploy the application as per the deployment section of")) } - assert(error1.getMessage.contains("Failed to find data source: avro.")) + } - val error2 = intercept[AnalysisException] { - getProvidingClass("com.databricks.spark.avro") - } - assert(error2.getMessage.contains("Failed to find data source: com.databricks.spark.avro.")) + test("kafka: show deploy guide for loading the external kafka module") { + val message = intercept[AnalysisException] { + getProvidingClass("kafka") + }.getMessage + assert(message.contains("Failed to find data source: kafka")) + assert(message.contains("Please deploy the application as per the deployment section of")) + } - val error3 = intercept[ClassNotFoundException] { + test("error message for unknown data sources") { + val error = intercept[ClassNotFoundException] { getProvidingClass("asfdwefasdfasdf") } - assert(error3.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) + assert(error.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 17690e3df9155..13a126ff963d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. 
class SimpleScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e0a53272cd222..f6c3e0ce82e3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.sources.v2 -import java.util.{ArrayList, List => JList} - import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector @@ -38,6 +36,21 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ + private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] + }.head + } + + private def getJavaScanConfig( + query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] + }.head + } + test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -50,18 +63,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - - def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] - }.head - } - Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -70,69 +71,58 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('j) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getJavaScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } val q2 = df.filter('i > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q2) - 
assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } else { - val reader = getJavaReader(q2) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getJavaScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } val q3 = df.select('i).filter('i > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } else { - val reader = getJavaReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getJavaScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } val q4 = df.select('j).filter('j < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q4) + val config = getScanConfig(q4) // 'j < 10 is not supported by the testing data source. - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q4) + val config = getJavaScanConfig(q4) // 'j < 10 is not supported by the testing data source. 
- assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } } } } - test("unsafe row scan implementation") { - Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => - withClue(cls.getName) { - val df = spark.read.format(cls.getName).load() - checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) - } - } - } - test("columnar batch scan implementation") { - Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) @@ -145,8 +135,8 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { - val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("requires a user-supplied schema")) + val e = intercept[IllegalArgumentException](spark.read.format(cls.getName).load()) + assert(e.getMessage.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() @@ -164,25 +154,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('a).agg(sum('b)) + val groupByColA = df.groupBy('i).agg(sum('j)) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(groupByColA.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + val groupByColAB = df.groupBy('i, 'j).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(groupByColAB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('b).agg(sum('a)) + val groupByColB = df.groupBy('j).agg(sum('i)) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(groupByColB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e @@ -203,33 +193,33 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val path = file.getCanonicalPath assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - spark.range(10).select('id, -'id).write.format(cls.getName) + spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).select('id, -'id)) // test with different save modes - spark.range(10).select('id, -'id).write.format(cls.getName) + 
spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).union(spark.range(10)).select('id, -'id)) - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("overwrite").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("ignore").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) val e = intercept[Exception] { - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("error").save() } assert(e.getMessage.contains("data already exists")) @@ -246,20 +236,13 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } // this input data will fail to read middle way. - val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) val e2 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } assert(e2.getMessage.contains("Writing job aborted")) // make sure we don't have partial data. assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - - // test internal row writer - spark.range(5).select('id, -'id).write.format(cls.getName) - .option("path", path).option("internal", "true").mode("overwrite").save() - checkAnswer( - spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) } } } @@ -271,7 +254,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) val numPartition = 6 - spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), @@ -290,321 +273,337 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23301: column pruning with arbitrary expressions") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() val q1 = df.select('i + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) - val reader1 = getReader(q1) - assert(reader1.requiredSchema.fieldNames === Seq("i")) + val config1 = getScanConfig(q1) + assert(config1.requiredSchema.fieldNames === Seq("i")) val q2 = df.select(lit(1)) checkAnswer(q2, (0 until 10).map(i => Row(1))) - val reader2 = getReader(q2) - assert(reader2.requiredSchema.isEmpty) + val config2 = getScanConfig(q2) + assert(config2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning val q3 = df.filter('j === -1).select('j * 2) checkAnswer(q3, Row(-2)) - val reader3 = getReader(q3) - 
assert(reader3.filters.isEmpty) - assert(reader3.requiredSchema.fieldNames === Seq("j")) + val config3 = getScanConfig(q3) + assert(config3.filters.isEmpty) + assert(config3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. val q4 = df.sort('i).limit(1).select('i + 1) checkAnswer(q4, Row(1)) - val reader4 = getReader(q4) - assert(reader4.requiredSchema.fieldNames === Seq("i")) + val config4 = getScanConfig(q4) + assert(config4.requiredSchema.fieldNames === Seq("i")) } test("SPARK-23315: get output from canonicalized data source v2 related plans") { - def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + def checkCanonicalizedOutput( + df: DataFrame, logicalNumOutput: Int, physicalNumOutput: Int): Unit = { val logical = df.queryExecution.optimizedPlan.collect { case d: DataSourceV2Relation => d }.head - assert(logical.canonicalized.output.length == numOutput) + assert(logical.canonicalized.output.length == logicalNumOutput) val physical = df.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => d }.head - assert(physical.canonicalized.output.length == numOutput) + assert(physical.canonicalized.output.length == physicalNumOutput) } val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - checkCanonicalizedOutput(df, 2) - checkCanonicalizedOutput(df.select('i), 1) + checkCanonicalizedOutput(df, 2, 2) + checkCanonicalizedOutput(df.select('i), 2, 1) } } -class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") +case class RangeInputPartition(start: Int, end: Int) extends InputPartition - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) - } - } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { + override def build(): ScanConfig = this } -class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { +object SimpleReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def next(): Boolean = { + current += 1 + current < end + } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) + override def get(): InternalRow = InternalRow(current, -current) + + override def close(): Unit = {} } } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[Row] - with DataReader[Row] { - private var current = start - 1 +abstract class SimpleReadSupport extends BatchReadSupport { + override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) - - override def next(): Boolean = { - current += 1 - current < end + override def newScanConfigBuilder(): 
ScanConfigBuilder = { + NoopScanConfigBuilder(fullSchema()) } - override def get(): Row = Row(current, -current) - - override def close(): Unit = {} + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SimpleReaderFactory + } } +class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { -class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5)) + } + } - class Reader extends DataSourceReader - with SupportsPushDownRequiredColumns with SupportsPushDownFilters { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - var requiredSchema = new StructType().add("i", "int").add("j", "int") - var filters = Array.empty[Filter] +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. +class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } + } - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (supported, unsupported) = filters.partition { - case GreaterThan("i", _: Int) => true - case _ => false - } - this.filters = supported - unsupported - } + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - override def pushedFilters(): Array[Filter] = filters - override def readSchema(): StructType = { - requiredSchema - } +class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { + + class ReadSupport extends SimpleReadSupport { + override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - val lowerBound = filters.collect { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters + + val lowerBound = filters.collectFirst { case GreaterThan("i", v: Int) => v - }.headOption + } - val res = new ArrayList[DataReaderFactory[Row]] + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] if (lowerBound.isEmpty) { - res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 4) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 9) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 10)) } - res + res.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema + new AdvancedReaderFactory(requiredSchema) } } - override def createReader(options: 
DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) - extends DataReaderFactory[Row] with DataReader[Row] { +class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - private var current = start - 1 + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] - override def createDataReader(): DataReader[Row] = { - new AdvancedDataReaderFactory(start, end, requiredSchema) + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema } - override def close(): Unit = {} + override def readSchema(): StructType = requiredSchema - override def next(): Boolean = { - current += 1 - current < end - } - - override def get(): Row = { - val values = requiredSchema.map(_.name).map { - case "i" => current - case "j" => -current + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false } - Row.fromSeq(values) + this.filters = supported + unsupported } + + override def pushedFilters(): Array[Filter] = filters + + override def build(): ScanConfig = this } +class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 -class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { + override def next(): Boolean = { + current += 1 + current < end + } - class Reader extends DataSourceReader with SupportsScanUnsafeRow { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def get(): InternalRow = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current + } + InternalRow.fromSeq(values) + } - override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), - new UnsafeRowDataReaderFactory(5, 10)) + override def close(): Unit = {} } } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class UnsafeRowDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { - private val row = new UnsafeRow(2) - row.pointTo(new Array[Byte](8 * 3), 8 * 3) +class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { - private var current = start - 1 + class ReadSupport(val schema: StructType) extends SimpleReadSupport { + override def fullSchema(): StructType = schema - override def createDataReader(): DataReader[UnsafeRow] = this - - override def next(): Boolean = { - current += 1 - current < end - } - override def get(): UnsafeRow = { - row.setInt(0, current) - row.setInt(1, -current) - row + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = + Array.empty } - override def close(): Unit = {} -} - -class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - - class Reader(val readSchema: StructType) extends DataSourceReader { - override def createDataReaderFactories(): 
JList[DataReaderFactory[Row]] = - java.util.Collections.emptyList() + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + throw new IllegalArgumentException("requires a user-supplied schema") } - override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = - new Reader(schema) + override def createBatchReadSupport( + schema: StructType, options: DataSourceOptions): BatchReadSupport = { + new ReadSupport(schema) + } } -class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { +class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - class Reader extends DataSourceReader with SupportsScanColumnarBatch { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) + } - override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { - java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + ColumnarReaderFactory } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class BatchDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { - +object ColumnarReaderFactory extends PartitionReaderFactory { private final val BATCH_SIZE = 20 - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch(Array(i, j)) - private var current = start + override def supportColumnarReads(partition: InputPartition): Boolean = true - override def createDataReader(): DataReader[ColumnarBatch] = this + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + throw new UnsupportedOperationException + } - override def next(): Boolean = { - i.reset() - j.reset() + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[ColumnarBatch] { + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch(Array(i, j)) + + private var current = start + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } - var count = 0 - while (current < end && count < BATCH_SIZE) { - i.putInt(count, current) - j.putInt(count, -current) - current += 1 - count += 1 - } + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } - if (count == 0) { - false - } else { - batch.setNumRows(count) - true - } - } + override def get(): ColumnarBatch = batch - override def get(): ColumnarBatch = { - batch + override def close(): Unit = batch.close() + } } - - override def close(): Unit = batch.close() } -class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with 
SupportsReportPartitioning { - override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") +class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { // Note that we don't have same value of column `a` across partitions. - java.util.Arrays.asList( - new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) + Array( + SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), + SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) } - override def outputPartitioning(): Partitioning = new MyPartitioning + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SpecificReaderFactory + } + + override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning } class MyPartitioning extends Partitioning { override def numPartitions(): Int = 2 override def satisfy(distribution: Distribution): Boolean = distribution match { - case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case c: ClusteredDistribution => c.clusteredColumns.contains("i") case _ => false } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) - extends DataReaderFactory[Row] - with DataReader[Row] { - assert(i.length == j.length) +case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition - private var current = -1 +object SpecificReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[SpecificInputPartition] + new PartitionReader[InternalRow] { + private var current = -1 - override def createDataReader(): DataReader[Row] = this - - override def next(): Boolean = { - current += 1 - current < i.length - } + override def next(): Boolean = { + current += 1 + current < p.i.length + } - override def get(): Row = Row(i(current), j(current)) + override def get(): InternalRow = InternalRow(p.i(current), p.j(current)) - override def close(): Unit = {} + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a5007fa321359..952241b0b6be5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,34 +18,36 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util.{Collections, List => JList, Optional} +import java.util.Optional import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow 
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} +import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration /** * A HDFS based transactional writable data source. - * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. - * Each job moves files from `target/_temporary/jobId/` to `target`. + * Each task writes data to `target/_temporary/queryId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/queryId/` to `target`. */ -class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { +class SimpleWritableDataSource extends DataSourceV2 + with BatchReadSupportProvider with BatchWriteSupportProvider { private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceReader { - override def readSchema(): StructType = schema + class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def fullSchema(): StructType = schema + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -53,21 +55,23 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.map { f => - val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVDataReaderFactory( - f.getPath.toUri.toString, - serializableConf): DataReaderFactory[Row] - }.toList.asJava + CSVInputPartitionReader(f.getPath.toUri.toString) + }.toArray } else { - Collections.emptyList() + Array.empty } } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val serializableConf = new SerializableConfiguration(conf) + new CSVReaderFactory(serializableConf) + } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { - override def createWriterFactory(): DataWriterFactory[Row] = { + class WritSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport { + override def createBatchWriterFactory(): DataWriterFactory = { SimpleCounter.resetCounter - new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -76,7 +80,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) - val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val jobPath = new Path(new Path(finalPath, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) try { for (file <- fs.listStatus(jobPath).map(_.getPath)) { @@ -91,40 +95,27 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } override def abort(messages: Array[WriterCommitMessage]): Unit = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) + val jobPath = new Path(new Path(path, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) fs.delete(jobPath, true) } } - 
class InternalRowWriter(jobId: String, path: String, conf: Configuration) - extends Writer(jobId, path, conf) with SupportsWriteInternalRow { - - override def createWriterFactory(): DataWriterFactory[Row] = { - throw new IllegalArgumentException("not expected!") - } - - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) - } - } - - override def createReader(options: DataSourceOptions): DataSourceReader = { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration - new Reader(path.toUri.toString, conf) + new ReadSupport(path.toUri.toString, conf) } - override def createWriter( - jobId: String, + override def createBatchWriteSupport( + queryId: String, schema: StructType, mode: SaveMode, - options: DataSourceOptions): Optional[DataSourceWriter] = { + options: DataSourceOptions): Optional[BatchWriteSupport] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) val path = new Path(options.get("path").get()) - val internal = options.get("internal").isPresent val conf = SparkContext.getActive.get.hadoopConfiguration val fs = path.getFileSystem(conf) @@ -142,49 +133,43 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS fs.delete(path, true) } - Optional.of(createWriter(jobId, path, conf, internal)) - } - - private def createWriter( - jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = { val pathStr = path.toUri.toString - if (internal) { - new InternalRowWriter(jobId, pathStr, conf) - } else { - new Writer(jobId, pathStr, conf) - } + Optional.of(new WritSupport(queryId, pathStr, conf)) } } -class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) - extends DataReaderFactory[Row] with DataReader[Row] { +case class CSVInputPartitionReader(path: String) extends InputPartition - @transient private var lines: Iterator[String] = _ - @transient private var currentLine: String = _ - @transient private var inputStream: FSDataInputStream = _ +class CSVReaderFactory(conf: SerializableConfiguration) + extends PartitionReaderFactory { - override def createDataReader(): DataReader[Row] = { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val path = partition.asInstanceOf[CSVInputPartitionReader].path val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) - inputStream = fs.open(filePath) - lines = new BufferedReader(new InputStreamReader(inputStream)) - .lines().iterator().asScala - this - } - override def next(): Boolean = { - if (lines.hasNext) { - currentLine = lines.next() - true - } else { - false - } - } + new PartitionReader[InternalRow] { + private val inputStream = fs.open(filePath) + private val lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + + private var currentLine: String = _ + + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } - override def get(): Row = Row(currentLine.split(",").map(_.trim.toLong): _*) + override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) - override def close(): Unit = { - inputStream.close() + override def close(): 
Unit = { + inputStream.close() + } + } } } @@ -204,57 +189,20 @@ private[v2] object SimpleCounter { } } -class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[Row] { - - override def createDataWriter( - partitionId: Int, - attemptNumber: Int, - epochId: Long): DataWriter[Row] = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") - val fs = filePath.getFileSystem(conf.value) - new SimpleCSVDataWriter(fs, filePath) - } -} - -class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { +class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory { - private val out = fs.create(file) - - override def write(record: Row): Unit = { - out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") - } - - override def commit(): WriterCommitMessage = { - out.close() - null - } - - override def abort(): Unit = { - try { - out.close() - } finally { - fs.delete(file, false) - } - } -} - -class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[InternalRow] { - - override def createDataWriter( + override def createWriter( partitionId: Int, - attemptNumber: Int, - epochId: Long): DataWriter[InternalRow] = { + taskId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) - new InternalRowCSVDataWriter(fs, filePath) + new CSVDataWriter(fs, filePath) } } -class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { +class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { private val out = fs.create(file) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index d6bef9ce07379..026af17c7b23f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -18,17 +18,22 @@ package org.apache.spark.sql.streaming import java.{util => ju} +import java.io.File import java.text.SimpleDateFormat import java.util.{Calendar, Date} +import org.apache.commons.io.FileUtils import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.util.Utils class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -122,39 +127,133 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(aggWithWatermark)( AddData(inputData2, 15), CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(15)) - assert(e.get("min") === 
formatTimestamp(15)) - assert(e.get("avg") === formatTimestamp(15)) - assert(e.get("watermark") === formatTimestamp(0)) - }, + assertEventStats(min = 15, max = 15, avg = 15, wtrmark = 0), AddData(inputData2, 10, 12, 14), CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(14)) - assert(e.get("min") === formatTimestamp(10)) - assert(e.get("avg") === formatTimestamp(12)) - assert(e.get("watermark") === formatTimestamp(5)) - }, - AddData(inputData2, 25), - CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - }, + assertEventStats(min = 10, max = 14, avg = 12, wtrmark = 5), AddData(inputData2, 25), CheckAnswer((10, 3)), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(15)) - } + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5) ) } + test("event time and watermark metrics with Trigger.Once (SPARK-24699)") { + // All event time metrics where watermarking is set + val inputData = MemoryStream[Int] + val aggWithWatermark = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + // Unlike the ProcessingTime trigger, Trigger.Once only runs one trigger every time + // the query is started and it does not run no-data batches. Hence the answer generated + // by the updated watermark is only generated the next time the query is started. + // Also, the data to process in the next trigger is added *before* starting the stream in + // Trigger.Once to ensure that first and only trigger picks up the new data. 
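
As a standalone illustration of the Trigger.Once behavior described in the comment above: a watermarked aggregation driven this way only processes the data already available, runs no no-data batch, and therefore surfaces watermark-finalized results on the next start from the same checkpoint. This is a minimal sketch under stated assumptions (an active SparkSession, the built-in rate source, a console sink, and a hypothetical checkpoint path); the suite's scripted version follows below.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().appName("trigger-once-watermark-sketch").getOrCreate()
import spark.implicits._

// 5-second event-time windows with a 10-second watermark over the built-in rate source.
val windowedCounts = spark.readStream
  .format("rate").option("rowsPerSecond", "5").load()
  .withWatermark("timestamp", "10 seconds")
  .groupBy(window($"timestamp", "5 seconds"))
  .agg(count("*").as("count"))

// Trigger.Once processes the available data in a single micro-batch and stops; results that
// become final only because this run advanced the watermark show up the next time the query
// is started against the same checkpoint location.
val query = windowedCounts.writeStream
  .outputMode("update")
  .format("console")
  .option("checkpointLocation", "/tmp/trigger-once-watermark-sketch")  // hypothetical path
  .trigger(Trigger.Once())
  .start()
query.awaitTermination()
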
+ + testStream(aggWithWatermark)( + StartStream(Trigger.Once), // to make sure the query is not running when adding data 1st time + awaitTermination(), + + AddData(inputData, 15), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 15, max = 15, avg = 15, wtrmark = 0), + // watermark should be updated to 15 - 10 = 5 + + AddData(inputData, 10, 12, 14), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 10, max = 14, avg = 12, wtrmark = 5), + // watermark should stay at 5 + + AddData(inputData, 25), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5), + // watermark should be updated to 25 - 10 = 15 + + AddData(inputData, 50), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer((10, 3)), // watermark = 15 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 15), + // watermark should be updated to 50 - 10 = 40 + + AddData(inputData, 50), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer((15, 1), (25, 1)), // watermark = 40 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 40)) + } + + test("recovery from Spark ver 2.3.1 commit log without commit metadata (SPARK-24699)") { + // All event time metrics where watermarking is set + val inputData = MemoryStream[Int] + val aggWithWatermark = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(15) + inputData.addData(10, 12, 14) + + testStream(aggWithWatermark)( + /* + + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + + StartStream(checkpointLocation = "./sql/core/src/test/resources/structured-streaming/" + + "checkpoint-version-2.3.1-without-commit-log-metadata/")), + AddData(inputData, 15), // watermark should be updated to 15 - 10 = 5 + CheckAnswer(), + AddData(inputData, 10, 12, 14), // watermark should stay at 5 + CheckAnswer(), + StopStream, + + // Offset log should have watermark recorded as 5. 
+ */ + + StartStream(Trigger.Once), + awaitTermination(), + + AddData(inputData, 25), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5), + // watermark should be updated to 25 - 10 = 15 + + AddData(inputData, 50), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer((10, 3)), // watermark = 15 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 15), + // watermark should be updated to 50 - 10 = 40 + + AddData(inputData, 50), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer((15, 1), (25, 1)), // watermark = 40 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 40)) + } + test("append mode") { val inputData = MemoryStream[Int] @@ -167,15 +266,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckNewAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - assertNumStateRows(3), - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10, 5)), + CheckNewAnswer((10, 5)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -193,15 +289,15 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation, OutputMode.Update)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch((10, 5), (15, 1)), + CheckNewAnswer((10, 5), (15, 1)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch((25, 1)), - assertNumStateRows(3), + CheckNewAnswer((25, 1)), + assertNumStateRows(2), AddData(inputData, 10, 25), // Ignore 10 as its less than watermark - CheckLastBatch((25, 2)), + CheckNewAnswer((25, 2)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -251,56 +347,25 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(df)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - StopStream, - StartStream(), - CheckLastBatch(), - AddData(inputData, 25), // Evict items less than previous watermark. 
- CheckLastBatch((10, 5)), + CheckAnswer((10, 5)), StopStream, AssertOnQuery { q => // purge commit and clear the sink - val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) q.commitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, StartStream(), - CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10 - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), + AddData(inputData, 10, 27, 30), // Advance watermark to 20 seconds, 10 should be ignored + CheckAnswer((15, 1)), StopStream, - StartStream(), // Watermark should still be 15 seconds - AddData(inputData, 17), - CheckLastBatch(), // We still do not see next batch - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), - AddData(inputData, 30), // Evict items less than previous watermark. - CheckLastBatch((15, 2)) // Ensure we see next window - ) - } - - test("dropping old data") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 12), - CheckAnswer(), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckAnswer((10, 3)), - AddData(inputData, 10), // 10 is later than 15 second watermark - CheckAnswer((10, 3)), - AddData(inputData, 25), - CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + StartStream(), + AddData(inputData, 17), // Watermark should still be 20 seconds, 17 should be ignored + CheckAnswer((15, 1)), + AddData(inputData, 40), // Advance watermark to 30 seconds, emit first data 25 + CheckNewAnswer((25, 2)) ) } @@ -421,8 +486,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 10), CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. 
CheckAnswer((10, 1)) ) } @@ -501,16 +564,183 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + // Check if there is new answer if flag is set, no new answer otherwise + if (flag) CheckNewAnswer((10, 5)) else CheckNewAnswer() + ) + } + + testWithFlag(true) + testWithFlag(false) + } + + test("MultipleWatermarkPolicy: max") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15), // max(20 - 10, 30 - 15) = 15 + StopStream, + StartStream(), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: min") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 10), // min(20 - 10, 30 - 15) = 10 + StopStream, + StartStream(), + checkWatermark(input1, 10), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // min(120 - 10, 130 - 15) = 110, policy recovered correctly + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) // does not advance when only one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from checkpoints ignores session conf") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val checkpointDir = Utils.createTempDir().getCanonicalFile + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15) // max(20 - 10, 30 - 15) = 15 + ) + } + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, 
policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from Spark ver 2.3.1 checkpoints ensures min policy") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + input1.addData(20) + input2.addData(30) + input1.addData(10) + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + Execute { _.processAllAvailable() }, + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // should calculate 'min' even if session conf has 'max' policy + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) + ) + } + } + + test("MultipleWatermarkPolicy: fail on incorrect conf values") { + val invalidValues = Seq("", "random") + invalidValues.foreach { value => + val e = intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, value) + } + assert(e.getMessage.toLowerCase.contains("valid values are 'min' and 'max'")) + } + } + + private def dfWithMultipleWatermarks( + input1: MemoryStream[Int], + input2: MemoryStream[Int]): Dataset[_] = { + val df1 = input1.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + val df2 = input2.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "15 seconds") + df1.union(df2).select($"eventTime".cast("int")) + } + + private def checkWatermark(input: MemoryStream[Int], watermark: Long) = Execute { q => + input.addData(1) + q.processAllAvailable() + assert(q.lastProgress.eventTime.get("watermark") == formatTimestamp(watermark)) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) true } + /** Assert event stats generated on that last batch with data in it */ private def assertEventStats(body: ju.Map[String, String] => Unit): AssertOnQuery = { - AssertOnQuery { q => + Execute("AssertEventStats") { q => body(q.recentProgress.filter(_.numInputRows > 0).lastOption.get.eventTime) - true + } + } + + /** Assert event stats generated on that last batch with data in it */ + private def assertEventStats(min: Long, max: Long, avg: Double, wtrmark: Long): AssertOnQuery = { + assertEventStats { e => + assert(e.get("min") === formatTimestamp(min), s"min value mismatch") + assert(e.get("max") === formatTimestamp(max), s"max value mismatch") + assert(e.get("avg") === formatTimestamp(avg.toLong), s"avg value mismatch") + assert(e.get("watermark") === formatTimestamp(wtrmark), s"watermark value mismatch") } } @@ -520,4 +750,8 @@ class EventTimeWatermarkSuite extends StreamTest with 
BeforeAndAfter with Matche private def formatTimestamp(sec: Long): String = { timestampFormat.format(new ju.Date(sec * 1000)) } + + private def awaitTermination(): AssertOnQuery = Execute("AwaitTermination") { q => + q.awaitTermination() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index cf41d7e0e4fe1..ed53def556cb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -279,13 +279,10 @@ class FileStreamSinkSuite extends StreamTest { check() // nothing emitted yet addTimestamp(104, 123) // watermark = 90 before this, watermark = 123 - 10 = 113 after this - check() // nothing emitted yet + check((100L, 105L) -> 2L) // no-data-batch emits results on 100-105, addTimestamp(140) // wm = 113 before this, emit results on 100-105, wm = 130 after this - check((100L, 105L) -> 2L) - - addTimestamp(150) // wm = 130s before this, emit results on 120-125, wm = 150 after this - check((100L, 105L) -> 2L, (120L, 125L) -> 1L) + check((100L, 105L) -> 2L, (120L, 125L) -> 1L) // no-data-batch emits results on 120-125 } finally { if (query != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index b1416bff87ee7..e77ba1ec9f1eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,24 +17,28 @@ package org.apache.spark.sql.streaming +import java.io.File import java.sql.Date -import java.util.concurrent.ConcurrentHashMap +import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils /** Class to check custom state types */ case class RunningCount(count: Long) @@ -359,13 +363,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val 
beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -396,7 +400,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -443,6 +447,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -453,6 +469,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = @@ -477,48 +517,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest priorTimeoutTimestamp = priorTimeoutTimestamp, expectedState = Some(5), // state should change expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - 
s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { + state.remove(); state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } } - // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + // Tests for InputProcessor.processTimedOutState() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -590,7 +603,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = 5000) // timestamp should change - test("flatMapGroupsWithState - streaming") { + testWithAllStateVersions("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -615,20 +628,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -657,19 +670,19 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), - CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "1"), ("a", "2"), ("b", "1")), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + 
CheckNewAnswer(("b", "2")), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) + CheckNewAnswer(("a", "1"), ("c", "1")) ) } - test("flatMapGroupsWithState - streaming + aggregation") { + testWithAllStateVersions("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -694,22 +707,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Complete)( AddData(inputData, "a"), - CheckLastBatch(("a", 1)), + CheckNewAnswer(("a", 1)), AddData(inputData, "a", "b"), // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), + CheckNewAnswer(("a", 2), ("b", 1)), StopStream, StartStream(), AddData(inputData, "a", "b"), // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), + CheckNewAnswer(("a", 3), ("b", 2)), StopStream, StartStream(), AddData(inputData, "a", "c"), // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) ) } @@ -728,9 +741,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } - test("flatMapGroupsWithState - streaming with processing time timeout") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + testWithAllStateVersions("flatMapGroupsWithState - streaming with processing time timeout") { + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. 
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCannotGetWatermark { state.getCurrentWatermarkMs() } @@ -757,17 +770,17 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "b"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("b", "1")), + CheckNewAnswer(("b", "1")), assertNumStateRows(total = 2, updated = 1), AddData(inputData, "b"), AdvanceManualClock(10 * 1000), - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, @@ -775,38 +788,117 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "c"), AdvanceManualClock(11 * 1000), - CheckLastBatch(("b", "-1"), ("c", "1")), + CheckNewAnswer(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), - AddData(inputData, "c"), - AdvanceManualClock(20 * 1000), - CheckLastBatch(("c", "2")), - assertNumStateRows(total = 1, updated = 1) + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows(total = 0, updated = 0) ) } - test("flatMapGroupsWithState - streaming with event time timeout + watermark") { - // Function to maintain the max event time - // Returns the max event time in the state, or -1 if the state was removed by timeout + testWithAllStateVersions("flatMapGroupsWithState - streaming w/ event time timeout + watermark") { + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. 
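
The event-time variant that the next test exercises follows the same shape, except the timeout is an absolute event-time timestamp: the key expires only once the query watermark (max event time minus the 10-second delay) crosses the per-key timestamp derived from its max event time. A minimal standalone sketch under stated assumptions (an active SparkSession, a keyed stream derived from the rate source, a console sink) is shown here; the suite's state function follows below.

import java.sql.Timestamp

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

val spark = SparkSession.builder().appName("event-time-timeout-sketch").getOrCreate()
import spark.implicits._

// Track the max event time (in seconds) per key and expire the key 5 seconds of event time
// after that maximum; EventTimeTimeout fires only when the watermark passes the timestamp.
def maxEventTimeWithTimeout(
    key: String,
    values: Iterator[(String, Timestamp)],
    state: GroupState[Long]): Iterator[(String, Long)] = {
  if (state.hasTimedOut) {
    state.remove()
    Iterator((key, -1L))
  } else {
    val valuesSeq = values.toSeq
    val maxEventTimeSec =
      math.max(valuesSeq.map(_._2.getTime / 1000).max, state.getOption.getOrElse(0L))
    state.update(maxEventTimeSec)
    state.setTimeoutTimestamp((maxEventTimeSec + 5) * 1000)  // absolute event-time millis
    Iterator((key, maxEventTimeSec))
  }
}

// Hypothetical keyed events derived from the rate source; any Dataset[(String, Timestamp)] works.
val events = spark.readStream
  .format("rate").option("rowsPerSecond", "1").load()
  .select(($"value" % 3).cast("string").as("key"), $"timestamp".as("eventTime"))
  .withWatermark("eventTime", "10 seconds")  // the watermark drives EventTimeTimeout
  .as[(String, Timestamp)]

events
  .groupByKey(_._1)
  .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.EventTimeTimeout)(maxEventTimeWithTimeout)
  .writeStream
  .outputMode(OutputMode.Update())
  .format("console")
  .start()
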
val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } - val timeoutDelay = 5 - if (key != "a") { - Iterator.empty + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) } else { - if (state.hasTimedOut) { - state.remove() - Iterator((key, -1)) - } else { - val valuesSeq = values.toSeq - val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) - val timeoutTimestampMs = maxEventTime + timeoutDelay - state.update(maxEventTime) - state.setTimeoutTimestamp(timeoutTimestampMs * 1000) - Iterator((key, maxEventTime.toInt)) - } + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + + test("flatMapGroupsWithState - uses state format version 2 by default") { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + Iterator((key, count.toString)) + } + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { query => + // Verify state format = 2 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 2) + } + ) + } + + test("flatMapGroupsWithState - recovery from checkpoint uses state format version 1") { + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. 
+ val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) } } val inputData = MemoryStream[(String, Int)] @@ -818,19 +910,51 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .groupByKey(_._1) .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(("a", 11), ("a", 13), ("a", 15)) + inputData.addData(("a", 4)) + testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second")), - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... - CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + StartStream( + checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2")), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckLastBatch(), // No output as data should get filtered by watermark - AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s - CheckLastBatch(), // No output as no data for "a" - AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored - CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + CheckNewAnswer(), // No output as data should get filtered by watermark + */ + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + Execute { query => + // Verify state format = 1 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 1) + }, + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. 
+ CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 ) } + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -856,20 +980,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -920,15 +1044,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( setFailInTask(false), AddData(inputData, "a"), - CheckLastBatch(("a", 1L)), + CheckNewAnswer(("a", 1L)), AddData(inputData, "a"), - CheckLastBatch(("a", 2L)), + CheckNewAnswer(("a", 2L)), setFailInTask(true), AddData(inputData, "a"), ExpectFailure[SparkException](), // task should fail but should not increment count setFailInTask(false), StartStream(), - CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count ) } @@ -938,7 +1062,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch("a"), + CheckNewAnswer("a"), AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) ) } @@ -1000,7 +1124,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, ("a", 1L)), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")) + CheckNewAnswer(("a", "1")) ) } } @@ -1020,7 +1144,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"StateStoreUpdater - updates with data - $testName") { + test(s"InputProcessor - process new data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -1042,7 +1166,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"StateStoreUpdater - updates for timeout - $testName") { + test(s"InputProcessor - process timed out state - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -1069,21 +1193,20 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty) { - val row = updater.getStateRow(priorState.get) - updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) - store.put(key.copy(), row.copy()) + if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { + stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() + inputProcessor.processTimedOutState() } else { - updater.updateStateForKeysWithData(Iterator(key)) + inputProcessor.processNewData(Iterator(key)) } returnedIter.size // consume the iterator to force state updates } @@ -1094,15 +1217,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } else { // Call function to update and verify updated state in store callFunction() - val updatedStateRow = store.get(key) - assert( - Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } @@ -1110,6 +1229,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { + val stateFormatVersion = spark.conf.get(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val emptyRdd = spark.sparkContext.emptyRDD[InternalRow] MemoryStream[Int] .toDS .groupByKey(x => x) @@ -1117,8 +1238,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( - f, k, v, g, d, o, None, s, m, t, - Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) + f, k, v, g, d, o, None, s, stateFormatVersion, m, t, + Some(currentBatchTimestamp), Some(currentBatchWatermark), + RDDScanExec(g, emptyRdd, "rdd")) }.get } @@ -1150,33 +1272,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } def rowToInt(row: UnsafeRow): Int = row.getInt(0) + + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { + for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { + test(s"$name - state format version $version") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> version.toString) { + func + } + } + } + } } object FlatMapGroupsWithStateSuite { var failInTask = true - class MemoryStateStore extends StateStore() { - import scala.collection.JavaConverters._ - private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - - override def iterator(): Iterator[UnsafeRowPair] = { - 
map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } - } - - override def get(key: UnsafeRow): UnsafeRow = map.get(key) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { - map.put(key.copy(), newValue.copy()) - } - override def remove(key: UnsafeRow): Unit = { map.remove(key) } - override def commit(): Long = version + 1 - override def abort(): Unit = { } - override def id: StateStoreId = null - override def version: Long = 0 - override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) - override def hasCommitted: Boolean = true - } - def assertCanGetProcessingTime(predicate: => Boolean): Unit = { if (!predicate) throw new TestFailedException("Could not get processing time", 20) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 368c4604dfca8..fb5d13d09fb0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,20 +17,62 @@ package org.apache.spark.sql.streaming +import org.apache.spark.sql.execution.streaming.StreamExecution + trait StateStoreMetricsTest extends StreamTest { + private var lastCheckedRecentProgressIndex = -1 + private var lastQuery: StreamExecution = null + + override def beforeEach(): Unit = { + super.beforeEach() + lastCheckedRecentProgressIndex = -1 + } + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert( - progressWithData.stateOperators.map(_.numRowsTotal) === total, - "incorrect total rows") - assert( - progressWithData.stateOperators.map(_.numRowsUpdated) === updated, - "incorrect updates rows") + // This assumes that the streaming query will not make any progress while the eventually + // is being executed. 
+ eventually(timeout(streamingTimeout)) { + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") + + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q + + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) + + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") + + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + + lastCheckedRecentProgressIndex = recentProgress.length - 1 + } true } def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = assertNumStateRows(Seq(total), Seq(updated)) + + def arraySum(arraySeq: Seq[Array[Long]], arrayLength: Int): Seq[Long] = { + if (arraySeq.isEmpty) return Seq.fill(arrayLength)(0L) + + assert(arraySeq.forall(_.length == arrayLength), + "Arrays are of different lengths:\n" + arraySeq.map(_.toSeq).mkString("\n")) + (0 until arrayLength).map { index => arraySeq.map(_.apply(index)).sum } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c1ec1eba69fb2..bf509b1976ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -27,6 +27,7 @@ import scala.util.control.ControlThrowable import com.google.common.util.concurrent.UncheckedExecutionException import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration +import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} @@ -35,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -513,6 +515,120 @@ class StreamSuite extends StreamTest { } } + test("explain-continuous") { + val inputData = ContinuousMemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test `df.explain` + val explain = ExplainCommand(df.queryExecution.logical, extended = false) + val explainString = + spark.sessionState + .executePlan(explain) + .executedPlan + .executeCollect() + .map(_.getString(0)) + .mkString("\n") + assert(explainString.contains("Filter")) + 
assert(explainString.contains("MapElements")) + assert(!explainString.contains("LocalTableScan")) + + // Test StreamingQuery.display + val q = df.writeStream.queryName("memory_continuous_explain") + .outputMode(OutputMode.Update()).format("memory") + .trigger(Trigger.Continuous("1 seconds")) + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + try { + // in continuous mode, the query will be run even there's no data + // sleep a bit to ensure initialization + eventually(timeout(2.seconds), interval(100.milliseconds)) { + assert(q.lastExecution != null) + } + + val explainWithoutExtended = q.explainInternal(false) + + // `extended = false` only displays the physical plan. + assert("Streaming RelationV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithoutExtended).size === 0) + assert("ScanV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithoutExtended).size === 1) + + val explainWithExtended = q.explainInternal(true) + // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical + // plan. + assert("Streaming RelationV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithExtended).size === 3) + assert("ScanV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithExtended).size === 1) + } finally { + q.stop() + } + } + + test("codegen-microbatch") { + val inputData = MemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test StreamingQuery.codegen + val q = df.writeStream.queryName("memory_microbatch_codegen") + .outputMode(OutputMode.Update) + .format("memory") + .trigger(Trigger.ProcessingTime("1 seconds")) + .start() + + try { + import org.apache.spark.sql.execution.debug._ + assert("No physical plan. Waiting for data." === codegenString(q)) + assert(codegenStringSeq(q).isEmpty) + + inputData.addData(1, 2, 3, 4, 5) + q.processAllAvailable() + + assertDebugCodegenResult(q) + } finally { + q.stop() + } + } + + test("codegen-continuous") { + val inputData = ContinuousMemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test StreamingQuery.codegen + val q = df.writeStream.queryName("memory_continuous_codegen") + .outputMode(OutputMode.Update) + .format("memory") + .trigger(Trigger.Continuous("1 seconds")) + .start() + + try { + // in continuous mode, the query will be run even there's no data + // sleep a bit to ensure initialization + eventually(timeout(2.seconds), interval(100.milliseconds)) { + assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution != null) + } + + assertDebugCodegenResult(q) + } finally { + q.stop() + } + } + + private def assertDebugCodegenResult(query: StreamingQuery): Unit = { + import org.apache.spark.sql.execution.debug._ + + val codegenStr = codegenString(query) + assert(codegenStr.contains("Found 1 WholeStageCodegen subtrees.")) + // assuming that code is generated for the test query + assert(codegenStr.contains("Generated code:")) + + val codegenStrSeq = codegenStringSeq(query) + assert(codegenStrSeq.nonEmpty) + assert(codegenStrSeq.head._1.contains("*(1)")) + assert(codegenStrSeq.head._2.contains("codegenStageId=1")) + } + test("SPARK-19065: dropDuplicates should not create expressions using the same id") { withTempPath { testPath => val data = Seq((1, 2), (2, 3), (3, 4)) @@ -805,6 +921,114 @@ class StreamSuite extends StreamTest { } } + test("streaming limit without state") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(0))( + AddData(inputData1, 1 to 8: _*), + CheckAnswer()) + + val inputData2 = 
MemoryStream[Int] + testStream(inputData2.toDF().limit(4))( + AddData(inputData2, 1 to 8: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with state") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().limit(4))( + AddData(inputData, 1 to 2: _*), + CheckAnswer(1 to 2: _*), + AddData(inputData, 3 to 6: _*), + CheckAnswer(1 to 4: _*), + AddData(inputData, 7 to 9: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with other operators") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().where("value % 2 = 1").limit(4))( + AddData(inputData, 1 to 5: _*), + CheckAnswer(1, 3, 5), + AddData(inputData, 6 to 9: _*), + CheckAnswer(1, 3, 5, 7), + AddData(inputData, 10 to 12: _*), + CheckAnswer(1, 3, 5, 7)) + } + + test("streaming limit with multiple limits") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(4).limit(2))( + AddData(inputData1, 1), + CheckAnswer(1), + AddData(inputData1, 2 to 8: _*), + CheckAnswer(1, 2)) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4).limit(100).limit(3))( + AddData(inputData2, 1, 2), + CheckAnswer(1, 2), + AddData(inputData2, 3 to 8: _*), + CheckAnswer(1 to 3: _*)) + } + + test("streaming limit in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(5).groupBy("value").count() + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 3: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 2), Row(2, 2), Row(3, 2), Row(4, 1), Row(5, 1))) + } + + test("streaming limits in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(4).groupBy("value").count().orderBy("value").limit(3) + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 2 to 6: _*), + CheckAnswer(Row(1, 1), Row(2, 2), Row(3, 2))) + } + + test("streaming limit in update mode") { + val inputData = MemoryStream[Int] + val e = intercept[AnalysisException] { + testStream(inputData.toDF().limit(5), OutputMode.Update())( + AddData(inputData, 1 to 3: _*) + ) + } + assert(e.getMessage.contains( + "Limits are not supported on streaming DataFrames/Datasets in Update output mode")) + } + + test("streaming limit in multiple partitions") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().repartition(2).limit(7))( + AddData(inputData, 1 to 10: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false), + AddData(inputData, 11 to 20: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false)) + } + + test("streaming limit in multiple partitions by column") { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF().repartition(2, $"_2").limit(7) + testStream(df)( + AddData(inputData, (1, 0), (2, 0), (3, 1), (4, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 4 && rows.forall(r => r.getInt(0) <= 4)), + false), + AddData(inputData, (5, 0), (6, 0), (7, 1), (8, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 8)), + false)) + } + for (e <- Seq( new InterruptedException, new InterruptedIOException, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 
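The limit tests above exercise the streaming limit support covered by this change: in Append and Complete mode a global limit keeps state recording how many rows have already been emitted, while Update mode rejects it. A rough standalone sketch of the Append-mode behaviour, again assuming MemoryStream and a memory sink outside the test harness:

    import org.apache.spark.sql.{SQLContext, SparkSession}
    import org.apache.spark.sql.execution.streaming.MemoryStream

    object StreamingLimitSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]").appName("streaming-limit-sketch").getOrCreate()
        import spark.implicits._
        implicit val sqlContext: SQLContext = spark.sqlContext

        val input = MemoryStream[Int]
        val limited = input.toDF().limit(4)

        val query = limited.writeStream
          .outputMode("append")
          .format("memory")
          .queryName("limit_sketch")
          .start()

        input.addData(1, 2)
        query.processAllAvailable()
        spark.table("limit_sketch").show() // 1, 2

        input.addData(3, 4, 5, 6)
        query.processAllAvailable()
        spark.table("limit_sketch").show() // still only 4 rows in total: 1, 2, 3, 4

        query.stop()
        spark.stop()
      }
    }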
af0268fa47871..491dc34afa143 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -192,14 +193,30 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + private def operatorName = if (lastOnly) "CheckLastBatchContains" else "CheckAnswerContains" } case class CheckAnswerRowsByFunc( globalCheckFunction: Seq[Row] => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName" - private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + override def toString: String = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + } + + case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"CheckNewAnswer: ${expectedAnswer.mkString(",")}" + } + + object CheckNewAnswer { + def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty) + + def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + + def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows) } /** Stops the stream. It must currently be running. 
*/ @@ -274,8 +291,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Execute arbitrary code */ object Execute { - def apply(func: StreamExecution => Any): AssertOnQuery = - AssertOnQuery(query => { func(query); true }) + def apply(name: String)(func: StreamExecution => Any): AssertOnQuery = + AssertOnQuery(query => { func(query); true }, "name") + + def apply(func: StreamExecution => Any): AssertOnQuery = apply("Execute")(func) } object AwaitEpoch { @@ -435,13 +454,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { + var lastFetchedMemorySinkLastBatchId: Long = -1 + + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") verify(currentStream != null, "stream not running") // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(sourceIndex, offset) + currentStream.awaitOffset(sourceIndex, offset, streamingTimeout.toMillis) + // Make sure all processing including no-data-batches have been executed + if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + currentStream.processAllAvailable() + } } } @@ -463,21 +493,28 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - val (latestBatchData, allData) = sink match { - case s: MemorySink => (s.latestBatchData, s.allData) - case s: MemorySinkV2 => (s.latestBatchData, s.allData) - } - try if (lastOnly) latestBatchData else allData catch { + val rows = try { + if (sinceLastFetchOnly) { + if (sink.latestBatchId.getOrElse(-1L) < lastFetchedMemorySinkLastBatchId) { + failTest("MemorySink was probably cleared since last fetch. 
Use CheckAnswer instead.") + } + sink.dataSinceBatch(lastFetchedMemorySinkLastBatchId) + } else { + if (lastOnly) sink.latestBatchData else sink.allData + } + } catch { case e: Exception => failTest("Exception while getting data from sink", e) } + lastFetchedMemorySinkLastBatchId = sink.latestBatchId.getOrElse(-1L) + rows } def executeAction(action: StreamAction): Unit = { logInfo(s"Processing test stream action: $action") action match { case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => - verify(currentStream == null, "stream already running") + verify(currentStream == null || !currentStream.isActive, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], "Use either SystemClock or StreamManualClock to start the stream") @@ -649,7 +686,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.reader + case r: StreamingDataSourceV2Relation => r.readSupport } .zipWithIndex .find(_._1 == source) @@ -698,14 +735,22 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } try { globalCheckFunction(sparkAnswer) } catch { case e: Throwable => failTest(e.toString) } + + case CheckNewAnswerRows(expectedAnswer) => + val sparkAnswer = fetchStreamAnswer(currentStream, sinceLastFetchOnly = true) + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } } - pos += 1 } try { @@ -719,8 +764,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { actns.foreach(executeAction) } + pos += 1 - case action: StreamAction => executeAction(action) + case action: StreamAction => + executeAction(action) + pos += 1 } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 1cae8cb8d47f1..1ae6ff3a90989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.{Locale, TimeZone} -import org.scalatest.Assertions -import org.scalatest.BeforeAndAfterAll +import org.apache.commons.io.FileUtils +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.rdd.BlockRDD @@ -31,13 +32,15 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} import org.apache.spark.sql.expressions.scalalang.typed import 
org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} +import org.apache.spark.util.Utils object FailureSingleton { var firstTime = true @@ -53,7 +56,35 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ - test("simple count, update mode") { + def executeFuncWithStateVersionSQLConf( + stateVersion: Int, + confPairs: Seq[(String, String)], + func: => Any): Unit = { + withSQLConf(confPairs ++ + Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) { + func + } + } + + def testWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StreamingAggregationStateManager.supportedVersions) { + test(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) + } + } + } + + def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StreamingAggregationStateManager.supportedVersions) { + testQuietly(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) + } + } + } + + testWithAllStateVersions("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -77,7 +108,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("count distinct") { + testWithAllStateVersions("count distinct") { val inputData = MemoryStream[(Int, Seq[Int])] val aggregated = @@ -93,7 +124,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, complete mode") { + testWithAllStateVersions("simple count, complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -116,7 +147,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, append mode") { + testWithAllStateVersions("simple count, append mode") { val inputData = MemoryStream[Int] val aggregated = @@ -133,7 +164,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("sort after aggregate in complete mode") { + testWithAllStateVersions("sort after aggregate in complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -158,7 +189,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("state metrics") { + testWithAllStateVersions("state metrics") { val inputData = MemoryStream[Int] val aggregated = @@ -211,7 +242,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("multiple keys") { + testWithAllStateVersions("multiple keys") { val inputData = MemoryStream[Int] val aggregated = @@ -228,7 +259,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testQuietly("midbatch failure") { + testQuietlyWithAllStateVersions("midbatch failure") { val inputData = MemoryStream[Int] FailureSingleton.firstTime = true val aggregated = @@ -254,7 +285,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("typed aggregators") { + testWithAllStateVersions("typed aggregators") { val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) @@ -264,7 +295,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - 
test("prune results by current_time, complete mode") { + testWithAllStateVersions("prune results by current_time, complete mode") { import testImplicits._ val clock = new StreamManualClock val inputData = MemoryStream[Long] @@ -316,7 +347,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_date, complete mode") { + testWithAllStateVersions("prune results by current_date, complete mode") { import testImplicits._ val clock = new StreamManualClock val tz = TimeZone.getDefault.getID @@ -365,7 +396,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-19690: do not convert batch aggregation in streaming query to streaming") { + testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in streaming query " + + "to streaming") { val streamInput = MemoryStream[Int] val batchDF = Seq(1, 2, 3, 4, 5) .toDF("value") @@ -429,7 +461,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest true } - test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD should be " + + "repartitioned to 1") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default @@ -467,8 +500,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + - "has non-empty grouping keys") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should still be " + + "repartitioned when it has non-empty grouping keys") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => @@ -520,7 +553,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-22230: last should change with new batches") { + testWithAllStateVersions("SPARK-22230: last should change with new batches") { val input = MemoryStream[Int] val aggregated = input.toDF().agg(last('value)) @@ -536,6 +569,82 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } + testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " + + "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error + // by ensuring the following. + // - A streaming query with a streaming aggregation. + // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate. + // - Post shuffle partition has exactly 128 records (i.e. the threshold at which + // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a + // micro-batch with 128 records that shuffle to a single partition. + // This test throws the exact error reported in SPARK-23004 without the corresponding fix. 
+ val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + + + test("simple count, update mode - recovery from checkpoint uses state format version 1") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(3) + inputData.addData(3, 2) + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2")), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)) + */ + + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + + Execute { query => + // Verify state format = 1 + val stateVersions = query.lastExecution.executedPlan.collect { + case f: StateStoreSaveExec => f.stateFormatVersion + case f: StateStoreRestoreExec => f.stateFormatVersion + } + assert(stateVersions.size == 2) + assert(stateVersions.forall(_ == 1)) + }, + + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)) + ) + } + + /** Add blocks of data to the `BlockRDDBackedSource`. 
*/ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 0088b64d6195e..42ffd472eb843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -97,28 +98,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { testStream(result, Append)( AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), - CheckLastBatch(10 to 15: _*), + CheckAnswer(10 to 15: _*), assertNumStateRows(total = 6, updated = 6), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(25), - assertNumStateRows(total = 7, updated = 1), - - AddData(inputData, 25), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0), + AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch drops rows <= 15 + CheckNewAnswer(25), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(total = 1, updated = 0), - AddData(inputData, 45), // Advance watermark to 35 seconds - CheckLastBatch(45), - assertNumStateRows(total = 2, updated = 1), - - AddData(inputData, 45), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0) + AddData(inputData, 45), // Advance watermark to 35 seconds, no-data-batch drops row 25 + CheckNewAnswer(45), + assertNumStateRows(total = 1, updated = 1) ) } @@ -141,33 +134,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) - // states in deduplicate is 10 to 15 and 25 - assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), - - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate - // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of - // window to evict items, so [15, 20) is still in the state store) - // states in deduplicate is 25 - assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate, emitted with no-data-batch + // states in aggregate in 
[15, 20) and [25, 30); no-data-batch removed [10, 14) + // states in deduplicate is 25, no-data-batch removed 10 to 14 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(1L, 1L)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckLastBatch(), assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), AddData(inputData, 40), // Advance watermark to 30 seconds - CheckLastBatch(), - // states in aggregate in [15, 20), [25, 30) and [40, 45) - // states in deduplicate is 25 and 40, - assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), - - AddData(inputData, 40), // Emit items less than watermark and drop their state CheckLastBatch((15 -> 1), (25 -> 1)), - // states in aggregate in [40, 45) - // states in deduplicate is 40, - assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + // states in aggregate is [40, 45); no-data-batch removed [15, 20) and [25, 30) + // states in deduplicate is 40; no-data-batch removed 25 + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)) ) } @@ -260,13 +240,13 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id") testStream(df)( AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), - CheckLastBatch(1, 2), + CheckAnswer(1, 2), AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), - CheckLastBatch(3, 4), + CheckNewAnswer(3, 4), AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark - CheckLastBatch(5, 6), + CheckNewAnswer(5, 6), AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark - CheckLastBatch(7) + CheckNewAnswer(7) ) } @@ -279,7 +259,37 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id", $"time".cast("long")) testStream(df)( AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), - CheckLastBatch(1 -> 1, 2 -> 2) + CheckAnswer(1 -> 1, 2 -> 2) ) } + + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(10, 11, 12, 13, 14, 15), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer(25), + { // State should have been cleaned if flag is set, otherwise should not have been cleaned + if (flag) assertNumStateRows(total = 1, updated = 1) + else assertNumStateRows(total = 7, updated = 1) + } + ) + } + + testWithFlag(true) + testWithFlag(false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 11bdd13942dcb..c5cc8df4356a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -62,20 +62,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1), CheckAnswer(), AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), AddData(input1, 10), // 10 arrived 
on input2 first, then input1, should join - CheckLastBatch((10, 20, 30)), + CheckNewAnswer((10, 20, 30)), AddData(input2, 1), // another 1 in input2 should join with 1 input1 - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), StopStream, StartStream(), AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) - CheckLastBatch((1, 2, 3), (1, 2, 3)), + CheckNewAnswer((1, 2, 3), (1, 2, 3)), StopStream, StartStream(), AddData(input1, 100), AddData(input2, 100), - CheckLastBatch((100, 200, 300)) + CheckNewAnswer((100, 200, 300)) ) } @@ -97,25 +97,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( AddData(input1, 1), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckNewAnswer((1, 10, 2, 3)), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), + CheckNewAnswer(), StopStream, StartStream(), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), + CheckNewAnswer((25, 30, 50, 75)), AddData(input1, 1), - CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark + CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark StopStream, StartStream(), AddData(input1, 5), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 5), - CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark + CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark ) } @@ -142,27 +142,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assertNumStateRows(total = 1, updated = 1), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckAnswer((1, 10, 2, 3)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15 - assertNumStateRows(total = 3, updated = 1), + CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10] + assertNumStateRows(total = 1, updated = 1), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10] + CheckNewAnswer((25, 30, 50, 75)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input2, 1), - CheckLastBatch(), // Should not join as < 15 removed - assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + CheckNewAnswer(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 AddData(input1, 5), - CheckLastBatch(), // Should not join or add to state as < 15 got filtered by watermark + CheckNewAnswer(), // Same reason as above assertNumStateRows(total = 2, updated = 0) ) } @@ -189,42 +189,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5)), CheckAnswer(), AddData(rightInput, (1, 11)), - CheckLastBatch((1, 5, 11)), + CheckNewAnswer((1, 5, 11)), AddData(rightInput, (1, 10)), - CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 1), + CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 + assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), - CheckLastBatch((1, 3, 10), (1, 3, 11)), + CheckNewAnswer((1, 3, 10), (1, 3, 11)), 
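These assertions follow the standard time-interval join pattern: both sides carry watermarks and the join condition bounds one event-time column relative to the other, which is what lets the join state on each side be aged out. A minimal sketch of that shape on its own, where the input DataFrames, column names, and watermark delays are assumptions rather than anything defined in this patch:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.expr

    // `leftEvents` and `rightEvents` are assumed streaming DataFrames with columns
    // (leftKey, leftTime) and (rightKey, rightTime), where the *Time columns are timestamps.
    def timeIntervalJoin(leftEvents: DataFrame, rightEvents: DataFrame): DataFrame = {
      val left = leftEvents.withWatermark("leftTime", "10 seconds")
      val right = rightEvents.withWatermark("rightTime", "5 seconds")

      // The range condition lets Spark derive state watermarks for both sides, so old
      // rows can be dropped from the join state once the event-time watermark passes them.
      left.join(
        right,
        expr("""
          leftKey = rightKey AND
          rightTime >= leftTime AND
          rightTime <= leftTime + interval 10 seconds
        """))
    }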
assertNumStateRows(total = 5, updated = 2), AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer(), // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // so left side going to only receive data where leftTime > 20 // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 - - // Run another batch with event time = 25 to clear right state where rightTime <= 25 - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 1), // removed (1, 11) and (1, 10), added (0, 30) + // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed + assertNumStateRows(total = 4, updated = 1), // New data to right input should match with left side (1, 3) and (1, 5), as left state should // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and // state rows with rightTime <= 25 should be removed from state. // (1, 20) ==> filtered by event time watermark = 20 // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state - // as state watermark = 25 + // as 21 < state watermark = 25 // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state AddData(rightInput, (1, 20), (1, 21), (1, 28)), - CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 5, updated = 1), // New data to left input with leftTime <= 20 should be filtered due to event time watermark AddData(leftInput, (1, 20), (1, 21)), - CheckLastBatch((1, 21, 28)), - assertNumStateRows(total = 7, updated = 1) + CheckNewAnswer((1, 21, 28)), + assertNumStateRows(total = 6, updated = 1) ) } @@ -275,38 +272,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 20)), CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), - CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 6), + CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), - CheckLastBatch(), // matches with nothing on the left + CheckNewAnswer(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), - CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 4), + CheckNewAnswer((1, 50, 60), (1, 65, 60)), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) // Should drop < 20 from left, i.e., none // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) // Should drop < 25 from the right, i.e., 14 and 15 - AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to stat - CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)), - assertNumStateRows(total = 11, updated = 1), // 12 - 2 removed + 1 added + assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed + + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state + CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1), // only 31 added // Advance the 
watermark AddData(rightInput, (1, 80)), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 1), - + CheckNewAnswer(), // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) // Should drop < 36 from left, i.e., 20, 31 (30 was not added) // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) // Should drop < 41 from the right, i.e., 25, 26, 30, 31 - AddData(rightInput, (1, 50)), - CheckLastBatch((1, 49, 50), (1, 50, 50)), - assertNumStateRows(total = 7, updated = 1) // 12 - 6 removed + 1 added + assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed + + AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state + CheckNewAnswer((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1) // 50 added ) } @@ -322,7 +320,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with input1.addData(1) q.awaitTermination(10000) } - assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) + assert(e.toString.contains("Stream-stream join without equality predicate is not supported")) } test("stream stream self join") { @@ -404,10 +402,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1, 5), AddData(input2, 1, 5, 10), AddData(input3, 5, 10), - CheckLastBatch((5, 10, 5, 15, 5, 25))) + CheckNewAnswer((5, 10, 5, 15, 5, 25))) + } + + test("streaming join should require HashClusteredDistribution from children") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) + val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) + val joined = df1.join(df2, Seq("a", "b")).select('a) + + testStream(joined)( + AddData(input1, 1.to(1000): _*), + AddData(input2, 1.to(1000): _*), + CheckAnswer(1.to(1000): _*)) } } + class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { import testImplicits._ @@ -465,13 +478,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -492,15 +505,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), - // The right rows with value <= 7 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The right rows with rightValue <= 7 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // rightValue = 9 > 7 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. 
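As the comment notes, once the watermark moves past a buffered row that never found a match, the outer side emits it with nulls; with this change that now happens in the no-data micro-batch that follows the data batch. A minimal sketch of a watermarked left outer stream-stream join, borrowing the impressions-and-clicks naming from the Structured Streaming guide; the input DataFrames and column names are illustrative assumptions:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.expr

    // `impressions` and `clicks` are assumed streaming DataFrames with columns
    // (impressionAdId, impressionTime) and (clickAdId, clickTime) respectively.
    def leftOuterAdJoin(impressions: DataFrame, clicks: DataFrame): DataFrame = {
      val impressionsWithWatermark = impressions.withWatermark("impressionTime", "10 seconds")
      val clicksWithWatermark = clicks.withWatermark("clickTime", "20 seconds")

      // Outer stream-stream joins require watermarks and a time-range condition, so Spark
      // knows when an unmatched impression can be emitted with null click columns.
      impressionsWithWatermark.join(
        clicksWithWatermark,
        expr("""
          clickAdId = impressionAdId AND
          clickTime >= impressionTime AND
          clickTime <= impressionTime + interval 1 hour
        """),
        "leftOuter")
    }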
- MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, 8, null), Row(5, 10, 10, null)), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -521,15 +534,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), - // The left rows with value <= 4 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The left rows with leftValue <= 4 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // leftValue = 7 > 4 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, null, "12"), Row(5, 10, null, "15")), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -552,13 +565,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -568,14 +581,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -586,14 +599,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. 
MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -627,21 +640,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5), (3, 5)), CheckAnswer(), AddData(rightInput, (1, 10), (2, 5)), - CheckLastBatch((1, 1, 5, 10)), + CheckNewAnswer((1, 1, 5, 10)), AddData(rightInput, (1, 11)), - CheckLastBatch(), // no match as left time is too low - assertNumStateRows(total = 5, updated = 1), + CheckNewAnswer(), // no match as left time is too low + assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), - CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + CheckNewAnswer((1, 1, 7, 10), (1, 1, 7, 11)), assertNumStateRows(total = 7, updated = 2), - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 1), - AddData(rightInput, (0, 30)), - CheckLastBatch(outerResult), - assertNumStateRows(total = 3, updated = 1) + AddData(rightInput, (0, 30)), // watermark = 30 - 10 = 20, no-data-batch computes nulls + CheckNewAnswer(outerResult), + assertNumStateRows(total = 2, updated = 1) ) } } @@ -665,36 +675,41 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), - CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 2), + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 3), // only right 1, 2, 3 added + + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch cleared < 10 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 20 and 21 left in state + AddData(rightInput, 20), - CheckLastBatch( - Row(20, 30, 40, 60)), + CheckNewAnswer(Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows - MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), - CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - MultiAddData(leftInput, 70)(rightInput, 71), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 2), + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), // watermark = 31 + CheckNewAnswer((40, 50, 80, 120), (41, 50, 82, 123)), + assertNumStateRows(total = 4, updated = 4), // only left 40, 41 + right 40,41 left in state + + MultiAddData(leftInput, 70)(rightInput, 71), 
// watermark = 60 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 70, 71 left in state + AddData(rightInput, 70), - CheckLastBatch((70, 80, 140, 210)), + CheckNewAnswer((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left - MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), - CheckLastBatch(), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), // watermark = 91 + CheckNewAnswer(), + assertNumStateRows(total = 6, updated = 3), // only 101 - 103 left in state + MultiAddData(leftInput, 1000)(rightInput, 1001), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), - AddData(rightInput, 1000), - CheckLastBatch( - Row(1000, 1010, 2000, 3000), + CheckNewAnswer( Row(101, 110, 202, null), Row(102, 110, 204, null), Row(103, 110, 206, null)), - assertNumStateRows(total = 3, updated = 1) + assertNumStateRows(total = 2, updated = 2) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b96f2bcbdd644..fe77a1b4469c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -231,7 +231,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { test("event ordering") { val listener = new EventCollector withListenerAdded(listener) { - for (i <- 1 to 100) { + for (i <- 1 to 50) { listener.reset() require(listener.startEvent === null) testStream(MemoryStream[Int].toDS)( @@ -299,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getEndOffset: OffsetV2 = { + override def latestOffset(): OffsetV2 = { numTriggers += 1 - super.getEndOffset + super.latestOffset() } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala new file mode 100644 index 0000000000000..1aaf8a9aa2d55 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import scala.language.reflectiveCalls + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamingQueryListener._ + + +class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + + override protected def sparkConf: SparkConf = + super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", + "org.apache.spark.sql.streaming.TestListener") + + test("test if the configured query lister is loaded") { + testStream(MemoryStream[Int].toDS)( + StartStream(), + StopStream + ) + + assert(TestListener.queryStartedEvent != null) + assert(TestListener.queryTerminatedEvent != null) + } + +} + +object TestListener { + @volatile var queryStartedEvent: QueryStartedEvent = null + @volatile var queryTerminatedEvent: QueryTerminatedEvent = null +} + +class TestListener(sparkConf: SparkConf) extends StreamingQueryListener { + + override def onQueryStarted(event: QueryStartedEvent): Unit = { + TestListener.queryStartedEvent = event + } + + override def onQueryProgress(event: QueryProgressEvent): Unit = {} + + override def onQueryTerminated(event: QueryTerminatedEvent): Unit = { + TestListener.queryTerminatedEvent = event + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 79bb827e0de93..7bef687e7e43b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -58,7 +58,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "stateOperators" : [ { | "numRowsTotal" : 0, | "numRowsUpdated" : 1, - | "memoryUsedBytes" : 2 + | "memoryUsedBytes" : 3, + | "customMetrics" : { + | "loadedMapCacheHitCount" : 1, + | "loadedMapCacheMissCount" : 0, + | "stateOnCurrentVersionSizeBytes" : 2 + | } | } ], | "sources" : [ { | "description" : "source", @@ -230,7 +235,11 @@ object StreamingQueryStatusAndProgressSuite { "avg" -> "2016-12-05T20:54:20.827Z", "watermark" -> "2016-12-05T20:54:20.827Z").asJava), stateOperators = Array(new StateOperatorProgress( - numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2)), + numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 3, + customMetrics = new java.util.HashMap(Map("stateOnCurrentVersionSizeBytes" -> 2L, + "loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L) + .mapValues(long2Long).asJava) + )), sources = Array( new SourceProgress( description = "source", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 20942ed93897c..73592526fb0f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.CountDownLatch +import scala.collection.mutable + import org.apache.commons.lang3.RandomStringUtils +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization import org.scalactic.TolerantNumerics import 
org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout @@ -29,13 +31,13 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -212,25 +214,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi private def dataAdded: Boolean = currentOffset.offset != -1 - // setOffsetRange should take 50 ms the first time it is called after data is added - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - if (dataAdded) clock.waitTillTime(1050) - super.setOffsetRange(start, end) - } - } - - // getEndOffset should take 100 ms the first time it is called after data is added - override def getEndOffset(): OffsetV2 = synchronized { - if (dataAdded) clock.waitTillTime(1150) - super.getEndOffset() + // latestOffset should take 50 ms the first time it is called after data is added + override def latestOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.latestOffset() } // getBatch should take 100 ms the first time it is called - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { synchronized { - clock.waitTillTime(1350) - super.createUnsafeRowReaderFactories() + clock.waitTillTime(1150) + super.planInputPartitions(config) } } } @@ -271,33 +265,26 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress when setOffsetRange is being called + // Test status and progress when `latestOffset` is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on `latestOffset` AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange + AdvanceManualClock(50), // time = 1050 to unblock `latestOffset` AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 - AssertOnQuery(_.status.isDataAvailable === false), - AssertOnQuery(_.status.isTriggerActive === true), - 
AssertOnQuery(_.status.message.startsWith("Getting offsets from")), - AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - - AdvanceManualClock(100), // time = 1150 to unblock getEndOffset - AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 + // will block on `planInputPartitions` that needs 1350 + AssertStreamExecThreadIsWaitingForTime(1150), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock createReadTasks - AssertClockTime(1350), + AdvanceManualClock(100), // time = 1150 to unblock `planInputPartitions` + AssertClockTime(1150), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), @@ -305,7 +292,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(150), // time = 1500 to unblock map task + AdvanceManualClock(350), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger @@ -325,17 +312,16 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("setOffsetRange") === 50) - assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 200) + assert(progress.durationMs.get("latestOffset") === 50) + assert(progress.durationMs.get("queryPlanning") === 100) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 150) + assert(progress.durationMs.get("addBatch") === 350) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === null) - assert(progress.sources(0).endOffset !== null) + assert(progress.sources(0).startOffset === null) // no prior offset + assert(progress.sources(0).endOffset === "0") assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) @@ -361,6 +347,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(query.lastProgress.batchId === 1) assert(query.lastProgress.inputRowsPerSecond === 2.0) assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).startOffset === "0") + assert(query.lastProgress.sources(0).endOffset === "1") true }, @@ -461,12 +449,50 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0) assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0) assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0) + assert(gauges.get("eventTime-watermark").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("states-rowsTotal").getValue.asInstanceOf[Long] == 0) + 
assert(gauges.get("states-usedBytes").getValue.asInstanceOf[Long] == 0) sq.stop() } } } - test("input row calculation with mixed batch and streaming sources") { + test("Check if custom metrics are reported") { + val streamInput = MemoryStream[Int] + implicit val formats = Serialization.formats(NoTypeHints) + testStream(streamInput.toDF(), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sink.customMetrics == "{\"numRows\":3}") + true + }, + AddData(streamInput, 4, 5, 6, 7), + CheckAnswer(1, 2, 3, 4, 5, 6, 7), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 4) + assert(lastProgress.get.sink.customMetrics == "{\"numRows\":7}") + true + } + ) + } + + test("input row calculation with same V1 source used twice in self-join") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + + val progress = getFirstProgress(streamingInputDF.join(streamingInputDF, "value")) + assert(progress.numInputRows === 20) // data is read multiple times in self-joins + assert(progress.sources.size === 1) + assert(progress.sources(0).numInputRows === 20) + } + + test("input row calculation with mixed batch and streaming V1 sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") @@ -479,7 +505,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } - test("input row calculation with trigger input DF having multiple leaves") { + test("input row calculation with trigger input DF having multiple leaves in V1 source") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) @@ -492,6 +518,121 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } + test("input row calculation with same V2 source used twice in self-union") { + val streamInput = MemoryStream[Int] + + testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 1, 2, 2, 3, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with same V2 source used twice in self-join") { + val streamInput = MemoryStream[Int] + val df = streamInput.toDF() + testStream(df.join(df, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with trigger having data for only one of two V2 
sources") { + val streamInput1 = MemoryStream[Int] + val streamInput2 = MemoryStream[Int] + + testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + AddData(streamInput1, 1, 2, 3), + CheckLastBatch(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 3) + assert(lastProgress.get.sources(1).numInputRows == 0) + true + }, + AddData(streamInput2, 4, 5), + CheckLastBatch(4, 5), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 2) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 0) + assert(lastProgress.get.sources(1).numInputRows == 2) + true + } + ) + } + + test("input row calculation with mixed batch and streaming V2 sources") { + + val streamInput = MemoryStream[Int] + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + + // The number of leaves in the trigger's logical plan should be same as the executed plan. + require( + q.lastExecution.logical.collectLeaves().length == + q.lastExecution.executedPlan.collectLeaves().length) + + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + + val streamInput2 = MemoryStream[Int] + val staticInputDF2 = staticInputDF.union(staticInputDF).cache() + + testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + AddData(streamInput2, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + // The number of leaves in the trigger's logical plan should be different from + // the executed plan. The static input will have two leaves in the logical plan + // (due to the union), but will be converted to a single leaf in the executed plan + // (due to the caching, the cached subplan is replaced by a single InMemoryTableScanExec). + require( + q.lastExecution.logical.collectLeaves().length != + q.lastExecution.executedPlan.collectLeaves().length) + + // Despite the mismatch in total number of leaves in the logical and executed plans, + // we should be able to attribute streaming input metrics to the streaming sources. 
+ val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + } + testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) @@ -706,6 +847,77 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckLastBatch(("A", 1))) } + test("Uuid in streaming query should not produce same uuids in each execution") { + val uuids = mutable.ArrayBuffer[String]() + def collectUuid: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach(r => uuids += r.getString(0)) + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(Uuid())) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectUuid), + AddData(stream, 2), + CheckAnswer(collectUuid) + ) + assert(uuids.distinct.size == 2) + } + + test("Rand/Randn in streaming query should not produce same results in each execution") { + val rands = mutable.ArrayBuffer[Double]() + def collectRand: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach { r => + rands += r.getDouble(0) + rands += r.getDouble(1) + } + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(new Rand()), new Column(new Randn())) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectRand), + AddData(stream, 2), + CheckAnswer(collectRand) + ) + assert(rands.distinct.size == 4) + } + + test("Shuffle in streaming query should not produce same results in each execution") { + val rands = mutable.ArrayBuffer[Seq[Int]]() + def collectShuffle: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach { r => + rands += r.getSeq[Int](0) + } + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(new Shuffle(Literal.create[Seq[Int]](0 until 100)))) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectShuffle), + AddData(stream, 2), + CheckAnswer(collectShuffle) + ) + assert(rands.distinct.size == 2) + } + + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + + "should not fail") { + val df = spark.readStream.format("rate").load() + assert(df.logicalPlan.toJSON.contains("StreamingRelationV2")) + + testStream(df)( + AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation")) + ) + + testStream(df, useV2Sink = true)( + StartStream(trigger = Trigger.Continuous(100)), + AssertOnQuery(_.logicalPlan.toJSON.contains("ContinuousExecutionRelation")) + ) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) @@ -733,6 +945,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + /** Returns the last query progress from query.recentProgress where numInputRows is positive */ + def getLastProgressWithData(q: StreamingQuery): Option[StreamingQueryProgress] = { + q.recentProgress.filter(_.numInputRows > 0).lastOption + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala new file mode 100644 index 0000000000000..c5b95fa9b64a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.OutputMode + +class ContinuousAggregationSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("not enabled") { + val ex = intercept[AnalysisException] { + val input = ContinuousMemoryStream.singlePartition[Int] + testStream(input.toDF().agg(max('value)), OutputMode.Complete)() + } + + assert(ex.getMessage.contains( + "In continuous processing mode, coalesce(1) must be called before aggregate operation")) + } + + test("basic") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + } + + test("multiple partitions with coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF().coalesce(1).agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with coalesce - multiple transformations") { + val input = ContinuousMemoryStream[Int] + + // We use a barrier to make sure predicates both before and after coalesce work + val df = input.toDF() + .select('value as 'copy, 'value) + .where('copy =!= 1) + .logicalPlan + .coalesce(1) + .where('copy =!= 2) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(0), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with multiple coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF() + .coalesce(1) + .logicalPlan + .coalesce(1) + .select('value as 'copy, 'value) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + 
StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("repeated restart") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + StartStream(), + StopStream, + StartStream(), + StopStream, + StartStream(), + AddData(input, 0), + CheckAnswer(2), + AddData(input, 5), + CheckAnswer(5)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala new file mode 100644 index 0000000000000..d6819eacd07ca --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} + +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} + +class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { + case class LongPartitionOffset(offset: Long) extends PartitionOffset + + val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest" + val startEpoch = 0 + + var epochEndpoint: RpcEndpointRef = _ + + override def beforeEach(): Unit = { + super.beforeEach() + epochEndpoint = EpochCoordinatorRef.create( + mock[StreamingWriteSupport], + mock[ContinuousReadSupport], + mock[ContinuousExecution], + coordinatorId, + startEpoch, + spark, + SparkEnv.get) + EpochTracker.initializeCurrentEpoch(0) + } + + override def afterEach(): Unit = { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochEndpoint = null + super.afterEach() + } + + + private val mockContext = mock[TaskContext] + when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY)) + .thenReturn(startEpoch.toString) + when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)) + .thenReturn(coordinatorId) + + /** + * Set up a ContinuousQueuedDataReader for testing. The blocking queue can be used to send + * rows to the wrapped data reader. 
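+   * Note (see the epoch marker tests below): a null returned from the reader's next() is
+   * expected to signal an epoch boundary rather than a data row.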
+ */ + private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { + val queue = new ArrayBlockingQueue[UnsafeRow](1024) + val partitionReader = new ContinuousPartitionReader[InternalRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } + + override def get = curr + + override def getOffset = LongPartitionOffset(index) + + override def close() = {} + } + val reader = new ContinuousQueuedDataReader( + 0, + partitionReader, + new StructType().add("i", "int"), + mockContext, + dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, + epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) + + (queue, reader) + } + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + test("basic data read") { + val (input, reader) = setup() + + input.add(unsafeRow(12345)) + assert(reader.next().getInt(0) == 12345) + } + + test("basic epoch marker") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } + + test("new rows after markers") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + } + + test("new markers after rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + } + + test("alternating markers and rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + assert(reader.next().getInt(0) == 11111) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + input.add(unsafeRow(33333)) + assert(reader.next().getInt(0) == 33333) + input.add(unsafeRow(44444)) + assert(reader.next().getInt(0) == 44444) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index c318b951ff992..3d21bc63e0cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r }.get val deltaMs = numTriggers * 1000 + 300 @@ 
-75,73 +75,50 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("map") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .map(r => r.getLong(0) * 2) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().map(_.getInt(0) * 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(0, 2), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(0, 2, 4, 6, 8)) } test("flatMap") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().flatMap(r => Seq(0, r.getInt(0), r.getInt(0) * 2)) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer((0 to 1).flatMap(n => Seq(0, n, n * 2)): _*), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer((0 to 4).flatMap(n => Seq(0, n, n * 2)): _*)) } test("filter") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .where('value > 5) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().where('value > 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("deduplicate") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .dropDuplicates() + val input = ContinuousMemoryStream[Int] + val df = input.toDF().dropDuplicates() val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -149,15 +126,11 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("timestamp") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select(current_timestamp()) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().select(current_timestamp()) val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -165,58 +138,43 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("subquery alias") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .createOrReplaceTempView("rate") - val test = 
spark.sql("select value from rate where value > 5") + val input = ContinuousMemoryStream[Int] + input.toDF().createOrReplaceTempView("memory") + val test = spark.sql("select value from memory where value > 2") - testStream(test, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(test)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("repeatedly restart") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(df)( + StartStream(), + AddData(input, 0, 1), + CheckAnswer(0, 1), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StartStream(), + StopStream, + AddData(input, 2, 3), + StartStream(), + CheckAnswer(0, 1, 2, 3), StopStream) } test("task failure kills the query") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() // Get an arbitrary task from this query to kill. It doesn't matter which one. var taskId: Long = -1 @@ -227,9 +185,9 @@ class ContinuousSuite extends ContinuousSuiteBase { } spark.sparkContext.addSparkListener(listener) try { - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(100)), - Execute(waitForRateSourceTriggers(_, 2)), + AddData(input, 0, 1, 2, 3), Execute { _ => // Wait until a task is started, then kill its first attempt. eventually(timeout(streamingTimeout)) { @@ -252,6 +210,7 @@ class ContinuousSuite extends ContinuousSuiteBase { .option("rowsPerSecond", "2") .load() .select('value) + val query = df.writeStream .format("memory") .queryName("noharness") @@ -338,3 +297,49 @@ class ContinuousStressSuite extends ContinuousSuiteBase { CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } + +class ContinuousMetaSuite extends ContinuousSuiteBase { + import testImplicits._ + + // We need to specify spark.sql.streaming.minBatchesToRetain to do the following test. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true") + .set("spark.sql.streaming.minBatchesToRetain", "2"))) + + test("SPARK-24351: check offsetLog/commitLog retained in the checkpoint directory") { + withTempDir { checkpointDir => + val input = ContinuousMemoryStream[Int] + val df = input.toDF().mapPartitions(iter => { + // Sleep the task thread for 300 ms to make sure epoch processing time 3 times + // longer than epoch creating interval. So the gap between last committed + // epoch and currentBatchId grows over time. 
+ Thread.sleep(300) + iter.map(row => row.getInt(0) * 2) + }) + + testStream(df)( + StartStream(trigger = Trigger.Continuous(100), + checkpointLocation = checkpointDir.getAbsolutePath), + AddData(input, 1), + CheckAnswer(2), + // Make sure epoch 2 has been committed before the following validation. + AwaitEpoch(2), + StopStream, + AssertOnQuery(q => { + q.commitLog.getLatest() match { + case Some((latestEpochId, _)) => + val commitLogValidateResult = q.commitLog.get(latestEpochId - 1).isDefined && + q.commitLog.get(latestEpochId - 2).isEmpty + val offsetLogValidateResult = q.offsetLog.get(latestEpochId - 1).isDefined && + q.offsetLog.get(latestEpochId - 2).isEmpty + commitLogValidateResult && offsetLogValidateResult + case None => false + } + }) + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 99e30561f81d5..3c973d8ebc704 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,20 +40,20 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writer: StreamWriter = _ + private var writeSupport: StreamingWriteSupport = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { - val reader = mock[ContinuousReader] - writer = mock[StreamWriter] + val reader = mock[ContinuousReadSupport] + writeSupport = mock[StreamingWriteSupport] query = mock[ContinuousExecution] - orderVerifier = inOrder(writer, query) + orderVerifier = inOrder(writeSupport, query) spark = new TestSparkSession() epochCoordinator - = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get) } test("single epoch") { @@ -120,7 +120,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + test("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { setWriterPartitions(2) setReaderPartitions(2) @@ -141,7 +141,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) @@ -162,7 +162,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 5 
-> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) @@ -209,12 +209,12 @@ class EpochCoordinatorSuite } private def verifyCommit(epoch: Long): Unit = { - orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(writeSupport).commit(eqTo(epoch), any()) orderVerifier.verify(query).commit(epoch) } private def verifyNoCommitFor(epoch: Long): Unit = { - verify(writer, never()).commit(eqTo(epoch), any()) + verify(writeSupport, never()).commit(eqTo(epoch), any()) verify(query, never()).commit(epoch) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala new file mode 100644 index 0000000000000..b42f8267916b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import scala.language.implicitConversions + +import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class ContinuousShuffleSuite extends StreamTest { + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. 
+ var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + TaskContext.unset() + ctx = null + super.afterEach() + } + + private implicit def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) + } + + private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + + private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { + rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + } + + private def readEpoch(rdd: ContinuousShuffleReadRDD) = { + rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) + } + + test("reader - one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) + ) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) + } + + test("reader - multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) + } + + test("reader - empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0) + ) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } + + test("reader - multiple partitions") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, + numPartitions = 5, + endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}")) + // Send all data before processing to ensure there's no crossover. + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index for identification. 
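+      // Each partition gets its own endpoint, so a row sent to partition i should only come
+      // back when that same partition is computed; the second loop below verifies this.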
+ send( + part.endpoint, + ReceiverRow(0, unsafeRow(part.index)), + ReceiverEpochMarker(0) + ) + } + + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } + + test("reader - blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) + val epoch = rdd.compute(rdd.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + try { + epoch.next().getInt(0) + } catch { + case _: InterruptedException => // do nothing - expected at test ending + } + } + } + + try { + readRowThread.start() + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + } finally { + readRowThread.interrupt() + readRowThread.join() + } + } + + test("reader - multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(1), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + + test("reader - epoch only ends when all writers send markers") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(2) + ) + + val epoch = rdd.compute(rdd.partitions(0), ctx) + val rows = (0 until 3).map(_ => epoch.next()).toSet + assert(rows.map(_.getUTF8String(0).toString) == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + + // After checking the right rows, block until we get an epoch marker indicating there's no next. + // (Also fail the assertion if for some reason we get a row.) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!epoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + // Send the last epoch marker - now the epoch should finish. + send(endpoint, ReceiverEpochMarker(1)) + eventually(timeout(streamingTimeout)) { + !readEpochMarkerThread.isAlive + } + + // Join to pick up assertion failures. + readEpochMarkerThread.join(streamingTimeout.toMillis) + } + + test("reader - writer epochs non aligned") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should + // collate them as though the markers were aligned in the first place. 
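+    // Expected collation, matching the assertions below: epoch 1 -> {writer0-row0},
+    // epoch 2 -> {writer0-row1, writer1-row0}, epoch 3 -> {writer1-row1, writer2-row0}.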
+ send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow("writer0-row1")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row1")), + ReceiverEpochMarker(1), + + ReceiverEpochMarker(2), + ReceiverEpochMarker(2), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(firstEpoch == Set("writer0-row0")) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(secondEpoch == Set("writer0-row1", "writer1-row0")) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) + } + + test("one epoch") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + } + + test("multiple epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + writer.write(Iterator(4, 5, 6)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + assert(readEpoch(reader) == Seq(4, 5, 6)) + } + + test("empty epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator()) + writer.write(Iterator(1, 2)) + writer.write(Iterator()) + writer.write(Iterator()) + writer.write(Iterator(3, 4)) + writer.write(Iterator()) + + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(1, 2)) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(3, 4)) + assert(readEpoch(reader) == Seq()) + } + + test("blocks waiting for writer") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) + } + } + readRowThread.start() + + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + + // Once we write the epoch the thread should stop waiting and succeed. 
+ writer.write(Iterator(1)) + readRowThread.join(streamingTimeout.toMillis) + } + + test("multiple writer partitions") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(0).write(Iterator(1, 4, 7)) + writers(1).write(Iterator(2, 5)) + writers(2).write(Iterator(3, 6)) + + writers(0).write(Iterator(4, 7, 10)) + writers(1).write(Iterator(5, 8)) + writers(2).write(Iterator(6, 9)) + + // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. + // The epochs should be deterministically preserved, however. + assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) + assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) + } + + test("reader epoch only ends when all writer partitions write it") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(1).write(Iterator()) + writers(2).write(Iterator()) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!readerEpoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + writers(0).write(Iterator()) + readEpochMarkerThread.join(streamingTimeout.toMillis) + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index af4618bed5456..aeef4c8fe9332 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -17,72 +17,74 @@ package org.apache.spark.sql.streaming.sources -import java.util.Optional - -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} import 
org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReader() extends MicroBatchReader with ContinuousReader { - def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} - def getStartOffset: Offset = RateStreamOffset(Map()) - def getEndOffset: Offset = RateStreamOffset(Map()) - def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) - def commit(end: Offset): Unit = {} - def readSchema(): StructType = StructType(Seq()) - def stop(): Unit = {} - def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - def setStartOffset(start: Optional[Offset]): Unit = {} - - def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { +case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSupport { + override def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + override def fullSchema(): StructType = StructType(Seq()) + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = null + override def initialOffset(): Offset = RateStreamOffset(Map()) + override def latestOffset(): Offset = RateStreamOffset(Map()) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { throw new IllegalStateException("fake source - cannot actually read") } } -trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { - override def createMicroBatchReader( - schema: Optional[StructType], +trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider { + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = FakeReader() + options: DataSourceOptions): MicroBatchReadSupport = FakeReadSupport() } -trait FakeContinuousReadSupport extends ContinuousReadSupport { - override 
def createContinuousReader( - schema: Optional[StructType], +trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider { + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = FakeReader() + options: DataSourceOptions): ContinuousReadSupport = FakeReadSupport() } -trait FakeStreamWriteSupport extends StreamWriteSupport { - override def createStreamWriter( +trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { + options: DataSourceOptions): StreamingWriteSupport = { throw new IllegalStateException("fake sink - cannot actually write") } } -class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupportProvider { override def shortName(): String = "fake-read-microbatch-only" } -class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-continuous-only" } class FakeReadBothModes extends DataSourceRegister - with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + with FakeMicroBatchReadSupportProvider with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-microbatch-continuous" } @@ -90,7 +92,7 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { +class FakeWriteSupportProvider extends DataSourceRegister with FakeStreamingWriteSupportProvider { override def shortName(): String = "fake-write-microbatch-continuous" } @@ -105,8 +107,8 @@ class FakeSink extends Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } -class FakeWriteV1Fallback extends DataSourceRegister - with FakeStreamWriteSupport with StreamSinkProvider { +class FakeWriteSupportProviderV1Fallback extends DataSourceRegister + with FakeStreamingWriteSupportProvider with StreamSinkProvider { override def createSink( sqlContext: SQLContext, @@ -189,11 +191,11 @@ class StreamingDataSourceV2Suite extends StreamTest { val v2Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteV1Fallback]) + .isInstanceOf[FakeWriteSupportProviderV1Fallback]) // Ensure we create a V1 sink with the config. Note the config is a comma separated // list, including other fake entries. - val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" + val fullSinkName = classOf[FakeWriteSupportProviderV1Fallback].getName withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { val v1Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) @@ -217,35 +219,37 @@ class StreamingDataSourceV2Suite extends StreamTest { val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. 
- case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) + case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => + case (_: ContinuousReadSupportProvider, _: StreamingWriteSupportProvider, + _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all case (r, _, _) - if !r.isInstanceOf[MicroBatchReadSupport] - && !r.isInstanceOf[ContinuousReadSupport] => + if !r.isInstanceOf[MicroBatchReadSupportProvider] + && !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => + case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") // Invalid - trigger is continuous but reader is not - case (r, _: StreamWriteSupport, _: ContinuousTrigger) - if !r.isInstanceOf[ContinuousReadSupport] => + case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") // Invalid - trigger is microbatch but reader is not case (r, _, t) - if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + if !r.isInstanceOf[MicroBatchReadSupportProvider] && + !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, s"Data source $read does not support microbatch processing") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 14b1feb2adc20..b65058fffd339 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be assert(LastOptions.parameters("doubleOpt") == "6.7") } - test("check jdbc() does not support partitioning or bucketing") { + test("check jdbc() does not support partitioning, bucketBy or sortBy") { val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath) var w = df.write.partitionBy("value") @@ -287,7 +287,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) - Seq("jdbc", "bucketing").foreach { s => + Seq("jdbc", "does not support bucketBy right now").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("sortBy must be used together with bucketBy").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.bucketBy(2, "value").sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "does not support bucketBy and sortBy right now").foreach { s => assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } diff 
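The case analysis above in StreamingDataSourceV2Suite boils down to a small decision table over the renamed provider interfaces. A condensed sketch of that table, with an invented helper name and ignoring the suite's distinction between pre-creation and post-creation failures:

    import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
    import org.apache.spark.sql.sources.v2._
    import org.apache.spark.sql.streaming.Trigger

    // Invented helper: true when the (source, sink, trigger) combination can run at all.
    def supportsCombination(read: AnyRef, write: AnyRef, trigger: Trigger): Boolean = {
      val canWrite = write.isInstanceOf[StreamingWriteSupportProvider]
      trigger match {
        case _: ContinuousTrigger => canWrite && read.isInstanceOf[ContinuousReadSupportProvider]
        case _                    => canWrite && read.isInstanceOf[MicroBatchReadSupportProvider]
      }
    }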
--git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 0cfe260e52152..615923fe02d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -136,6 +136,19 @@ private[sql] trait SQLTestData { self => df } + protected lazy val lowerCaseDataWithDuplicates: DataFrame = { + val df = spark.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.createOrReplaceTempView("lowerCaseData") + df + } + protected lazy val arrayData: RDD[ArrayData] = { val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: @@ -255,6 +268,17 @@ private[sql] trait SQLTestData { self => df } + protected lazy val trainingSales: DataFrame = { + val df = spark.sparkContext.parallelize( + TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) :: + TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)) :: + TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)) :: + TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) :: + TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF() + df.createOrReplaceTempView("trainingSales") + df + } + /** * Initialize all test data such that all temp tables are properly registered. */ @@ -310,4 +334,5 @@ private[sql] object SQLTestData { case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) case class CourseSales(course: String, year: Int, earnings: Double) + case class TrainingSales(training: String, sales: CourseSales) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index bc4a120f7042f..2fb8f70a20791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -76,7 +76,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with /** * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * ConsoleAppender's `follow` should be set to `true` so that it will honor reassignments of * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if * we change System.out and System.err. 
*/ @@ -391,6 +391,13 @@ private[sql] trait SQLTestUtilsBase val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.makeQualified(hadoopPath).toUri } + + /** + * Returns full path to the given file in the resource folder + */ + protected def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } } private[sql] object SQLTestUtils { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index c5ade65283045..10000f12ab329 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -18,6 +18,8 @@ package org.apache.hive.service.auth; import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -26,6 +28,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import javax.net.ssl.SSLServerSocket; import javax.security.auth.login.LoginException; @@ -92,7 +95,30 @@ public String getAuthName() { public static final String HS2_PROXY_USER = "hive.server2.proxy.user"; public static final String HS2_CLIENT_TOKEN = "hiveserver2ClientToken"; - public HiveAuthFactory(HiveConf conf) throws TTransportException { + private static Field keytabFile = null; + private static Method getKeytab = null; + static { + Class clz = UserGroupInformation.class; + try { + keytabFile = clz.getDeclaredField("keytabFile"); + keytabFile.setAccessible(true); + } catch (NoSuchFieldException nfe) { + LOG.debug("Cannot find private field \"keytabFile\" in class: " + + UserGroupInformation.class.getCanonicalName(), nfe); + keytabFile = null; + } + + try { + getKeytab = clz.getDeclaredMethod("getKeytab"); + getKeytab.setAccessible(true); + } catch(NoSuchMethodException nme) { + LOG.debug("Cannot find private method \"getKeytab\" in class:" + + UserGroupInformation.class.getCanonicalName(), nme); + getKeytab = null; + } + } + + public HiveAuthFactory(HiveConf conf) throws TTransportException, IOException { this.conf = conf; transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE); authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION); @@ -107,9 +133,16 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException { authTypeStr = AuthTypes.NONE.getAuthName(); } if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) { - saslServer = ShimLoader.getHadoopThriftAuthBridge() - .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB), - conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)); + String principal = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL); + String keytab = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); + if (needUgiLogin(UserGroupInformation.getCurrentUser(), + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keytab)) { + saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(principal, keytab); + } else { + // Using the default constructor to avoid unnecessary UGI login. 
+ saslServer = new HadoopThriftAuthBridge.Server(); + } + // start delegation token manager try { // rawStore is only necessary for DBTokenStore @@ -362,4 +395,25 @@ public static void verifyProxyAccess(String realUser, String proxyUser, String i } } + public static boolean needUgiLogin(UserGroupInformation ugi, String principal, String keytab) { + return null == ugi || !ugi.hasKerberosCredentials() || !ugi.getUserName().equals(principal) || + !Objects.equals(keytab, getKeytabFromUgi()); + } + + private static String getKeytabFromUgi() { + synchronized (UserGroupInformation.class) { + try { + if (keytabFile != null) { + return (String) keytabFile.get(null); + } else if (getKeytab != null) { + return (String) getKeytab.invoke(UserGroupInformation.getCurrentUser()); + } else { + return null; + } + } catch (Exception e) { + LOG.debug("Fail to get keytabFile path via reflection", e); + return null; + } + } + } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java index 2e21f18d61268..adb269aa235ea 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java @@ -349,7 +349,7 @@ public void addValue(Type type, Object field) { break; case FLOAT_TYPE: nulls.set(size, field == null); - doubleVars()[size] = field == null ? 0 : ((Float)field).doubleValue(); + doubleVars()[size] = field == null ? 0 : new Double(field.toString()); break; case DOUBLE_TYPE: nulls.set(size, field == null); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index f59cdcd3188e6..745f385e87f78 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -471,7 +471,7 @@ private OperationHandle executeStatementInternal(String statement, Map currentDB) + cli.printMasterAndAppId + var currentPrompt = promptWithCurrentDB var line = reader.readLine(currentPrompt + "> ") @@ -300,10 +302,6 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) - if (sessionState.getIsSilent) { - Logger.getRootLogger.setLevel(Level.WARN) - } - private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } @@ -315,6 +313,9 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // because the Hive unit tests do not go through the main() code path. 
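The HiveAuthFactory change above caches reflective handles to UserGroupInformation's private keytabFile field and getKeytab method so that needUgiLogin can compare the configured keytab against the one the current UGI already holds and skip a redundant Kerberos login. A minimal Scala sketch of the same defensive-reflection idea, assuming only that one of the two members exists on the running Hadoop version (error handling around invoke is elided):

    import java.lang.reflect.{Field, Method}
    import org.apache.hadoop.security.UserGroupInformation

    object UgiKeytab {
      // Try the private static field first, then the private getter; keep None if absent.
      private val keytabField: Option[Field] =
        try {
          val f = classOf[UserGroupInformation].getDeclaredField("keytabFile")
          f.setAccessible(true)
          Some(f)
        } catch { case _: NoSuchFieldException => None }

      private val getKeytabMethod: Option[Method] =
        try {
          val m = classOf[UserGroupInformation].getDeclaredMethod("getKeytab")
          m.setAccessible(true)
          Some(m)
        } catch { case _: NoSuchMethodException => None }

      // Mirrors getKeytabFromUgi above: null when neither reflective handle is available.
      def currentKeytab(): String = (keytabField, getKeytabMethod) match {
        case (Some(f), _) => f.get(null).asInstanceOf[String]
        case (_, Some(m)) => m.invoke(UserGroupInformation.getCurrentUser).asInstanceOf[String]
        case _ => null
      }
    }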
if (!isRemoteMode) { SparkSQLEnv.init() + if (sessionState.getIsSilent) { + SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString) + } } else { // Hive 1.2 + not supported in CLI throw new RuntimeException("Remote operations not supported") @@ -324,6 +325,12 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { hiveVariables.asScala.foreach(kv => SparkSQLEnv.sqlContext.conf.setConfString(kv._1, kv._2)) } + def printMasterAndAppId(): Unit = { + val master = SparkSQLEnv.sparkContext.master + val appId = SparkSQLEnv.sparkContext.applicationId + console.printInfo(s"Spark master: $master, Application Id: $appId") + } + override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ad1f5eb9ca3a7..1335e16e35882 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,7 +27,7 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.Utils -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory @@ -52,8 +52,22 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC if (UserGroupInformation.isSecurityEnabled) { try { - HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = Utils.getUGI() + val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL) + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB) + if (principal.isEmpty || keyTabFile.isEmpty) { + throw new IOException( + "HiveServer2 Kerberos principal or keytab is not correctly configured") + } + + val originalUgi = UserGroupInformation.getCurrentUser + sparkServiceUGI = if (HiveAuthFactory.needUgiLogin(originalUgi, + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)) { + HiveAuthFactory.loginFromKeytab(hiveConf) + Utils.getUGI() + } else { + originalUgi + } + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index cbd75ad12d430..8980bcf885589 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,7 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) 
metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index f517bffccdf31..771104ceb8842 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,10 +47,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {listener.getOnlineSessionNum} session(s) are online, running {listener.getTotalRunning} SQL statement(s) ++ - generateSessionStatsTable() ++ - generateSQLStatsTable() + generateSessionStatsTable(request) ++ + generateSQLStatsTable(request) } - UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -67,7 +67,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", @@ -76,7 +76,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } @@ -138,7 +139,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { + private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = { val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { @@ -146,8 +147,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) + val sessionLink = "%s/%s/session/?id=%s".format( + UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 5cd2fdf6437c2..163eb43aabc72 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -56,9 +56,9 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Session created at {formatDate(sessionStat.startTimestamp)}, Total run {sessionStat.totalExecution} SQL ++ - generateSQLStatsTable(sessionStat.sessionId) + 
generateSQLStatsTable(request, sessionStat.sessionId) } - UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -75,7 +75,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(sessionID: String): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = { val executionList = listener.getExecutionList .filter(_.sessionId == sessionID) val numStatement = executionList.size @@ -86,7 +86,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 192f33a45e273..70eb28cdd0c64 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -636,6 +636,14 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(pipeoutFileList(sessionID).length == 0) } } + + test("SPARK-24829 Checks cast as float") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SELECT CAST('4.56' AS FLOAT)") + resultSet.next() + assert(resultSet.getString(1) === "4.56") + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -766,6 +774,14 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } + + test("SPARK-24829 Checks cast as float") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SELECT CAST('4.56' AS FLOAT)") + resultSet.next() + assert(resultSet.getString(1) === "4.56") + } + } } object ServerMode extends Enumeration { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index cebaad5b4ad9b..b9b2b7dbf38e8 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone def testCases: Seq[(String, File)] = { @@ -59,6 +60,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) + // Ensure that limit operation 
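The two SPARK-24829 tests above guard the Column.java change earlier in this patch, where a FLOAT result is now converted through the value's decimal string instead of Float.doubleValue(). Widening the float is numerically exact, but the widened value renders with extra digits, which is what leaked into JDBC results. A quick illustration of the difference, runnable in a Scala REPL:

    val f = 4.56f
    f.toDouble            // 4.559999942779541 -- exact widening, surprising rendering
    f.toString.toDouble   // 4.56              -- round-tripping through the decimal string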
returns rows in the same order as Hive + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") @@ -73,6 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 28c340a176d91..5cc1047fc067b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -114,7 +114,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * should interpret these special data source properties and restore the original table metadata * before returning it. */ - private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient { + private[hive] def getRawTable(db: String, table: String): CatalogTable = { client.getTable(db, table) } @@ -138,17 +138,37 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Checks the validity of data column names. Hive metastore disallows the table to use comma in - * data column names. Partition columns do not have such a restriction. Views do not have such - * a restriction. + * Checks the validity of data column names. Hive metastore disallows the table to use some + * special characters (',', ':', and ';') in data column names, including nested column names. + * Partition columns do not have such a restriction. Views do not have such a restriction. */ private def verifyDataSchema( tableName: TableIdentifier, tableType: CatalogTableType, dataSchema: StructType): Unit = { if (tableType != VIEW) { - dataSchema.map(_.name).foreach { colName => - if (colName.contains(",")) { - throw new AnalysisException("Cannot create a table having a column whose name contains " + - s"commas in Hive metastore. Table: $tableName; Column: $colName") + val invalidChars = Seq(",", ":", ";") + def verifyNestedColumnNames(schema: StructType): Unit = schema.foreach { f => + f.dataType match { + case st: StructType => verifyNestedColumnNames(st) + case _ if invalidChars.exists(f.name.contains) => + val invalidCharsString = invalidChars.map(c => s"'$c'").mkString(", ") + val errMsg = "Cannot create a table having a nested column whose name contains " + + s"invalid characters ($invalidCharsString) in Hive metastore. Table: $tableName; " + + s"Column: ${f.name}" + throw new AnalysisException(errMsg) + case _ => + } + } + + dataSchema.foreach { f => + f.dataType match { + // Checks top-level column names + case _ if f.name.contains(",") => + throw new AnalysisException("Cannot create a table having a column whose name " + + s"contains commas in Hive metastore. 
Table: $tableName; Column: ${f.name}") + // Checks nested column names + case st: StructType => + verifyNestedColumnNames(st) + case _ => } } } @@ -158,13 +178,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -177,7 +197,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -211,7 +231,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -480,7 +500,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -489,7 +509,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = withClient { @@ -540,7 +560,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. */ - override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { + override def alterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -624,7 +644,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * data schema should not have conflict column names with the existing partition columns, and * should still contain all the existing data columns. */ - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = withClient { @@ -656,7 +676,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. 
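The verifyDataSchema change above keeps the old top-level comma check but adds a recursive walk over struct fields to reject ',', ':' and ';' in nested column names. A compact sketch of that traversal, mirroring the hunk (only StructType is descended into; the names below are invented for illustration):

    import org.apache.spark.sql.types.{IntegerType, StructType}

    val invalidChars = Seq(",", ":", ";")

    def findInvalidNestedName(schema: StructType): Option[String] =
      schema.flatMap { f =>
        f.dataType match {
          case st: StructType => findInvalidNestedName(st)
          case _ if invalidChars.exists(f.name.contains) => Some(f.name)
          case _ => None
        }
      }.headOption

    // findInvalidNestedName(new StructType().add("a", new StructType().add("b:c", IntegerType)))
    // returns Some("b:c"), which HiveExternalCatalog turns into an AnalysisException.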
*/ - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { @@ -765,9 +785,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // schema we read back is different(ignore case and nullability) from the one in table // properties which was written when creating table, we should respect the table schema // from hive. - logWarning(s"The table schema given by Hive metastore(${table.schema.simpleString}) is " + + logWarning(s"The table schema given by Hive metastore(${table.schema.catalogString}) is " + "different from the schema when this table was created by Spark SQL" + - s"(${schemaFromTableProps.simpleString}). We have to fall back to the table schema " + + s"(${schemaFromTableProps.catalogString}). We have to fall back to the table schema " + "from Hive metastore which is not case preserving.") hiveTable.copy(schemaPreservesCase = false) } @@ -1208,7 +1228,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction( + override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1221,12 +1241,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doDropFunction(db: String, name: String): Unit = withClient { + override def dropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override protected def doAlterFunction( + override def alterFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) @@ -1235,7 +1255,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = withClient { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index e5aff3b99d0b9..de41bb418181d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, ExternalCatalog, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalogBuilder: () => 
HiveExternalCatalog, + externalCatalogBuilder: () => ExternalCatalog, globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, @@ -175,6 +175,10 @@ private[sql] class HiveSessionCatalog( super.functionExists(name) || hiveFunctions.contains(name.funcName) } + override def isPersistentFunction(name: FunctionIdentifier): Boolean = { + super.isPersistentFunction(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 40b9bb51ca9a0..2882672f327c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner @@ -35,14 +36,14 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { - private def externalCatalog: HiveExternalCatalog = - session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private def externalCatalog: ExternalCatalogWithListener = session.sharedState.externalCatalog /** * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - new HiveSessionResourceLoader(session, () => externalCatalog.client) + new HiveSessionResourceLoader( + session, () => externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8df05cbb20361..9fe83bb332a9a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -87,7 +87,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. 
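The resource-loader change above, together with the SparkSQLEnv, SaveAsHiveFile and TestHive hunks elsewhere in this patch, replaces direct casts of sharedState.externalCatalog with a call to unwrapped. As the hunks show, the shared catalog is now an ExternalCatalogWithListener wrapper (the name suggests it publishes catalog events to listeners), so Hive-specific code has to unwrap it before reaching the underlying HiveExternalCatalog. A minimal sketch, assuming a SparkSession in scope named session:

    import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener
    import org.apache.spark.sql.hive.HiveExternalCatalog

    val wrapped: ExternalCatalogWithListener = session.sharedState.externalCatalog
    // A direct asInstanceOf[HiveExternalCatalog] on the wrapper would now fail at runtime
    // with a ClassCastException; going through unwrapped reaches the real catalog.
    val client = wrapped.unwrapped.asInstanceOf[HiveExternalCatalog].client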
val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -114,7 +114,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -145,7 +145,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, @@ -186,15 +186,28 @@ case class RelationConversions( serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. if (serde.contains("parquet")) { - val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = relation.tableMeta.storage.properties + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { sessionCatalog.metastoreCatalog.convertToLogicalRelation( relation, @@ -212,7 +225,7 @@ case class RelationConversions( } override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { + plan resolveOperators { // Write path case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists) // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 10c9603745379..cd321d41f43e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. 
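In the RelationConversions hunk above, ORC- and Parquet-related table properties are now merged into the converted relation's options, with storage properties taking precedence because they sit on the right-hand side of ++. A small illustration of that precedence with made-up property values:

    val tableProps   = Map("orc.compress" -> "ZLIB", "comment" -> "not an orc.* key")
    val storageProps = Map("orc.compress" -> "SNAPPY")

    def isOrcProperty(key: String) = key.startsWith("orc.") || key.contains(".orc.")

    val options = tableProps.filterKeys(isOrcProperty) ++ storageProps
    // options == Map("orc.compress" -> "SNAPPY"): the storage-level value supersedes the
    // table property, and "comment" is dropped because it is not ORC-related.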
Available options are " + - s"0.12.0 through 2.3.2.") + s"0.12.0 through 2.3.3.") .stringConf .createWithDefault(builtinHiveVersion) @@ -105,11 +105,10 @@ private[spark] object HiveUtils extends Logging { .createWithDefault(false) val CONVERT_METASTORE_ORC = buildConf("spark.sql.hive.convertMetastoreOrc") - .internal() .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index b5444a4217924..7d57389947576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -110,8 +110,9 @@ class HadoopTableReader( deserializerClass: Class[_ <: Deserializer], filterOpt: Option[PathFilter]): RDD[InternalRow] = { - assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, - since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") + assert(!hiveTable.isPartitioned, + "makeRDDForTable() cannot be called on a partitioned table, since input formats may " + + "differ across partitions. Use makeRDDForPartitionedTable() instead.") // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index da9fe2d3088b4..02c1ed93eb2f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -353,15 +353,19 @@ private[hive] class HiveClientImpl( client.getDatabasesByPattern(pattern).asScala } + private def getRawTableOption(dbName: String, tableName: String): Option[HiveTable] = { + Option(client.getTable(dbName, tableName, false /* do not throw exception */)) + } + override def tableExists(dbName: String, tableName: String): Boolean = withHiveState { - Option(client.getTable(dbName, tableName, false /* do not throw exception */)).nonEmpty + getRawTableOption(dbName, tableName).nonEmpty } override def getTableOption( dbName: String, tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up $dbName.$tableName") - Option(client.getTable(dbName, tableName, false)).map { h => + getRawTableOption(dbName, tableName).map { h => // Note: Hive separates partition columns and the schema, but for us the // partition columns are part of the schema val cols = h.getCols.asScala.map(fromHiveColumn) @@ -923,6 +927,9 @@ private[hive] object HiveClientImpl { case CatalogTableType.MANAGED => HiveTableType.MANAGED_TABLE case CatalogTableType.VIEW => HiveTableType.VIRTUAL_VIEW + case t => + throw new IllegalArgumentException( + s"Unknown table type is found at toHiveTable: $t") }) // Note: In Hive the schema and partition columns must be disjoint sets val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => @@ -995,6 +1002,8 @@ private[hive] object HiveClientImpl { 
tpart.setTableName(ht.getTableName) tpart.setValues(partValues.asJava) tpart.setSd(storageDesc) + tpart.setCreateTime((p.createTime / 1000).toInt) + tpart.setLastAccessTime((p.lastAccessTime / 1000).toInt) tpart.setParameters(mutable.Map(p.parameters.toSeq: _*).asJava) new HivePartition(ht, tpart) } @@ -1019,6 +1028,8 @@ private[hive] object HiveClientImpl { compressed = apiPartition.getSd.isCompressed, properties = Option(apiPartition.getSd.getSerdeInfo.getParameters) .map(_.asScala.toMap).orNull), + createTime = apiPartition.getCreateTime.toLong * 1000, + lastAccessTime = apiPartition.getLastAccessTime.toLong * 1000, parameters = properties, stats = readHiveStats(properties)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 948ba542b5733..bc9d4cd7f4181 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -46,7 +45,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -343,7 +342,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { - conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000L } override def loadPartition( @@ -599,6 +598,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { object ExtractableLiteral { def unapply(expr: Expression): Option[String] = expr match { + case Literal(null, _) => None // `null`s can be cast as other types; we want to avoid NPEs. case Literal(value, _: IntegralType) => Some(value.toString) case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString)) case _ => None @@ -607,7 +607,23 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { object ExtractableLiterals { def unapply(exprs: Seq[Expression]): Option[Seq[String]] = { - val extractables = exprs.map(ExtractableLiteral.unapply) + // SPARK-24879: The Hive metastore filter parser does not support "null", but we still want + // to push down as many predicates as we can while still maintaining correctness. + // In SQL, the `IN` expression evaluates as follows: + // > `1 in (2, NULL)` -> NULL + // > `1 in (1, NULL)` -> true + // > `1 in (2)` -> false + // Since Hive metastore filters are NULL-intolerant binary operations joined only by + // `AND` and `OR`, we can treat `NULL` as `false` and thus rewrite `1 in (2, NULL)` as + // `1 in (2)`. 
+ // If the Hive metastore begins supporting NULL-tolerant predicates and Spark starts + // pushing down these predicates, then this optimization will become incorrect and need + // to be changed. + val extractables = exprs + .filter { + case Literal(null, _) => false + case _ => true + }.map(ExtractableLiteral.unapply) if (extractables.nonEmpty && extractables.forall(_.isDefined)) { Some(extractables.map(_.get)) } else { @@ -657,17 +673,32 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled + object ExtractAttribute { + def unapply(expr: Expression): Option[Attribute] = { + expr match { + case attr: Attribute => Some(attr) + case Cast(child @ AtomicType(), dt: AtomicType, _) + if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child) + case _ => None + } + } + } + def convert(expr: Expression): Option[String] = expr match { - case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced => + case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced => + case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) => + case op @ SpecialBinaryComparison( + ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) => Some(s"$name ${op.symbol} $value") - case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) => + case op @ SpecialBinaryComparison( + ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) => Some(s"$value ${op.symbol} $name") case And(expr1, expr2) if useAdvanced => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index c2690ec32b9e7..6a90c44a2633d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -98,7 +98,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 } private def downloadVersion( @@ -182,6 +182,7 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 + name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 681ee9200f02b..25e9886fa6576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.2", + 
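The Shim_v0_13 changes above make partition-filter pushdown both safer and more permissive: null literals are dropped from IN lists before building the Hive metastore filter, and safe casts around the partition attribute are peeled off via ExtractAttribute and Cast.canSafeCast. Dropping nulls is sound because the metastore filter language is NULL-intolerant and only combines predicates with AND/OR, so a NULL outcome excludes a partition exactly like false does. The three-valued-logic fact the comment relies on, checked with a SparkSession named spark (assumed in scope):

    spark.sql("SELECT 1 IN (2, NULL)").show()   // NULL  -> partition would not be selected
    spark.sql("SELECT 1 IN (1, NULL)").show()   // true
    spark.sql("SELECT 1 IN (2)").show()         // false -> same pruning outcome as the NULL case
    // So a predicate like `p IN (1, NULL)` can be pushed down as the metastore filter (p = 1).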
case object v2_3 extends HiveVersion("2.3.3", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 1e801fe1845c4..27d807cc35627 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommand /** * Create table and insert the query result into it. * - * @param tableDesc the Table Describe, which may contains serde, storage handler etc. + * @param tableDesc the Table Describe, which may contain serde, storage handler etc. * @param query the query whose result will be insert into the new relation * @param mode SaveMode */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 7dcaf170f9693..b3795b4430404 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -62,6 +62,8 @@ case class HiveTableScanExec( override def conf: SQLConf = sparkSession.sessionState.conf + override def nodeName: String = s"Scan hive ${relation.tableMeta.qualifiedName}" + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -78,9 +80,9 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. 
private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => - require( - pred.dataType == BooleanType, - s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") + require(pred.dataType == BooleanType, + s"Data type of predicate $pred must be ${BooleanType.catalogString} rather than " + + s"${pred.dataType.catalogString}.") BindReferences.bindReference(pred, relation.partitionCols) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 6a7b25b36d9a5..e0f7375387d24 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -122,7 +122,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { allSupportedHiveVersions) val externalCatalog = sparkSession.sharedState.externalCatalog - val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version + val hiveVersion = externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.version val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 237ed9bc05988..de8085f07db19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration /** @@ -72,6 +72,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val configuration = job.getConfiguration @@ -121,6 +122,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => @@ -162,7 +164,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + Option(TaskContext.get()) + .foreach(_.addTaskCompletionListener[Unit](_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s OrcFileFormat.unwrapOrcStructs( @@ -174,6 +177,23 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case 
ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 80e44ca504356..713b70f252b6a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -92,11 +92,12 @@ private[hive] object OrcFileOperator extends Logging { : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. - paths.flatMap(getFileReader(_, conf, ignoreCorruptFiles)).headOption.map { reader => - val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] - val schema = readerInspector.getTypeName - logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") - CatalystSqlParser.parseDataType(schema).asInstanceOf[StructType] + paths.toIterator.map(getFileReader(_, conf, ignoreCorruptFiles)).collectFirst { + case Some(reader) => + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema).asInstanceOf[StructType] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 965aea2b61456..ee3f99ab7e9bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -83,11 +84,11 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: TestHiveExternalCatalog = { - new TestHiveExternalCatalog( + override lazy val externalCatalog: ExternalCatalogWithListener = { + new ExternalCatalogWithListener(new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, - hiveClient) + hiveClient)) } } @@ -208,7 +209,9 @@ private[hive] class TestHiveSparkSession( new TestHiveSessionStateBuilder(this, parentSessionState).build() } - lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + lazy val metadataHive: HiveClient = { + sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() + } override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, 
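The OrcFileOperator.readSchema change above swaps an eager paths.flatMap(...).headOption for an iterator pipeline. On a Seq, flatMap opens a reader for every path before headOption discards all but the first; mapping over paths.toIterator and using collectFirst stops as soon as one usable reader turns up. A generic sketch of the pattern with an invented helper name:

    // Invokes `open` path by path and stops at the first one that yields a result.
    def firstDefined[A, B](paths: Seq[A])(open: A => Option[B]): Option[B] =
      paths.toIterator.map(open).collectFirst { case Some(b) => b }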
loadTestTables) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index a8cbd4fab15bb..48891fdcb1d80 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -676,7 +676,7 @@ public int compareTo(Complex other) { } int lastComparison = 0; - Complex typedOther = (Complex)other; + Complex typedOther = other; lastComparison = Boolean.valueOf(isSetAint()).compareTo(typedOther.isSetAint()); if (lastComparison != 0) { diff --git a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py index 341a1b40e07af..5b360208d36f6 100644 --- a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py +++ b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py @@ -18,6 +18,9 @@ # import sys +if sys.version_info[0] >= 3: + xrange = range + for i in xrange(50): for j in xrange(5): for k in xrange(20022): diff --git a/sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 b/sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 index 06461b525b058..967e2d3956414 100644 --- a/sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 +++ b/sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 @@ -1 +1 @@ -instr(str, substr) - Returns the index of the first occurance of substr in str +instr(str, substr) - Returns the index of the first occurrence of substr in str diff --git a/sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 b/sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 index 5a8c34271f443..0a745342a4ce9 100644 --- a/sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 +++ b/sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 @@ -1,4 +1,4 @@ -instr(str, substr) - Returns the index of the first occurance of substr in str +instr(str, substr) - Returns the index of the first occurrence of substr in str Example: > SELECT instr('Facebook', 'boo') FROM src LIMIT 1; 5 diff --git a/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e b/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e index 84bea329540d1..8e70b0c89b594 100644 --- a/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e +++ b/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e @@ -1 +1 @@ -locate(substr, str[, pos]) - Returns the position of the first occurance of substr in str after position pos +locate(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos diff --git a/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 b/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 index 092e12586b9e8..e103255a31f03 100644 --- a/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 +++ b/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 @@ -1,4 +1,4 @@ -locate(substr, str[, pos]) - Returns the position of the first occurance of substr in str after position pos +locate(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos Example: > SELECT locate('bar', 
'foobarbar', 5) FROM src LIMIT 1; 7 diff --git a/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 b/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 index 9ced4ee32cf0b..6caa4b679111d 100644 --- a/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 +++ b/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 @@ -6,8 +6,8 @@ translate('abcdef', 'adc', '19') returns '1b9ef' replacing 'a' with '1', 'd' wit translate('a b c d', ' ', '') return 'abcd' removing all spaces from the input string -If the same character is present multiple times in the input string, the first occurence of the character is the one that's considered for matching. However, it is not recommended to have the same character more than once in the from string since it's not required and adds to confusion. +If the same character is present multiple times in the input string, the first occurrence of the character is the one that's considered for matching. However, it is not recommended to have the same character more than once in the from string since it's not required and adds to confusion. For example, -translate('abcdef', 'ada', '192') returns '1bc9ef' replaces 'a' with '1' and 'd' with '9' ignoring the second occurence of 'a' in the from string mapping it to '2' +translate('abcdef', 'ada', '192') returns '1bc9ef' replaces 'a' with '1' and 'd' with '9' ignoring the second occurrence of 'a' in the from string mapping it to '2' diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q index 965b0b7ed0a3e..633150b5cf544 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q @@ -43,7 +43,7 @@ analyze table loc_orc compute statistics for columns state,locid,zip,year; -- dept_orc - 4 -- loc_orc - 8 --- count distincts for relevant columns (since count distinct values are approximate in some cases count distint values will be greater than number of rows) +-- count distincts for relevant columns (since count distinct values are approximate in some cases count distinct values will be greater than number of rows) -- emp_orc.deptid - 3 -- emp_orc.lastname - 7 -- dept_orc.deptid - 6 diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q index da2e26fde7069..e8289772e7544 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q @@ -26,7 +26,7 @@ set hive.optimize.bucketmapjoin.sortedmerge=true; -- Since size is being used to find the big table, the order of the tables in the join does not matter -- The tables are only bucketed and not sorted, the join should not be converted --- Currenly, a join is only converted to a sort-merge join without a hint, automatic conversion to +-- Currently, a join is only converted to a sort-merge join without a hint, automatic conversion to -- bucketized mapjoin is not done explain extended select count(*) FROM bucket_small a JOIN bucket_big b ON a.key = b.key; select count(*) FROM bucket_small a JOIN bucket_big b ON a.key = b.key; diff --git 
a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q index 6fe5117026ce8..e4ed7195a0575 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q @@ -69,5 +69,5 @@ SELECT * FROM episodes_partitioned WHERE doctor_pt > 6 ORDER BY air_date; SELECT * FROM episodes_partitioned ORDER BY air_date LIMIT 5; -- Fetch w/filter to specific partition SELECT * FROM episodes_partitioned WHERE doctor_pt = 6; --- Fetch w/non-existant partition +-- Fetch w/non-existent partition SELECT * FROM episodes_partitioned WHERE doctor_pt = 7 LIMIT 5; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q index 0c9f1b86a9e97..39d2d248a311f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q @@ -22,7 +22,7 @@ SELECT key + (value/2) FROM DECIMAL_UDF; EXPLAIN SELECT key + '1.0' FROM DECIMAL_UDF; SELECT key + '1.0' FROM DECIMAL_UDF; --- substraction +-- subtraction EXPLAIN SELECT key - key FROM DECIMAL_UDF; SELECT key - key FROM DECIMAL_UDF; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q index 3aeae0d5c33d6..d677fe65245ed 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q @@ -13,7 +13,7 @@ INSERT OVERWRITE TABLE dest1 SELECT substr(src.key,1,1), count(DISTINCT substr(s SELECT dest1.* FROM dest1 ORDER BY key; --- HIVE-5560 when group by key is used in distinct funtion, invalid result are returned +-- HIVE-5560 when group by key is used in distinct function, invalid result are returned EXPLAIN FROM src diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q index f53295e4b2435..69d671aa47116 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q @@ -12,7 +12,7 @@ LOAD DATA LOCAL INPATH '../../data/files/T1.txt' INTO TABLE T1 PARTITION (ds='1' INSERT OVERWRITE TABLE T1 PARTITION (ds='1') select key, val from T1 where ds = '1'; -- The plan is not converted to a map-side, since although the sorting columns and grouping --- columns match, the user is issueing a distinct. +-- columns match, the user is issuing a distinct. 
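// Not part of this patch: a standalone sketch of the recursive walk performed by the
// supportDataType override added to the Hive OrcFileFormat earlier in this diff --
// descend through struct/array/map wrappers and decide support at the leaf types.
// The leaf rule below is an illustrative assumption, not the exact ORC rule (the real
// override also unwraps UDTs and only allows NullType on the read path).
import org.apache.spark.sql.types._

object DataTypeWalkSketch {
  def isSupported(dataType: DataType): Boolean = dataType match {
    case st: StructType => st.fields.forall(f => isSupported(f.dataType))
    case ArrayType(elementType, _) => isSupported(elementType)
    case MapType(keyType, valueType, _) => isSupported(keyType) && isSupported(valueType)
    case _: CalendarIntervalType => false   // e.g. rejected by file-based sources
    case _ => true
  }
}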
-- However, after HIVE-4310, partial aggregation is performed on the mapper EXPLAIN select count(distinct key) from T1; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 48ab4eb9a6178..569f00c053e5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -38,7 +38,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val plan = table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index d10a6f25c64fc..4550d350f6db2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -268,12 +268,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compressionCodecs = compressCodecs, tableCompressionCodecs = compressCodecs) { case (tableCodec, sessionCodec, realCodec, tableSize) => - // For non-partitioned table and when convertMetastore is true, Expect session-level - // take effect, and in other cases expect table-level take effect - // TODO: It should always be table-level taking effect when the bug(SPARK-22926) - // is fixed - val expectCodec = - if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + val expectCodec = tableCodec.get assert(expectCodec == realCodec) assert(checkTableSize( format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 0a522b6a11c80..1de258f060943 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,4 +113,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } + + test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { + val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + assert(!client1.equals(client2)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6ca58e68d31eb..5103aa8a207db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -56,39 +56,47 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { } private def tryDownloadSpark(version: String, path: String): Unit = { - // Try mirrors a few times until one succeeds - for (i <- 0 until 3) { - // 
we don't retry on a failure to get mirror url. If we can't get a mirror url, - // the test fails (getStringFromUrl will throw an exception) - val preferredMirror = - getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + // Try a few mirrors first; fall back to Apache archive + val mirrors = + (0 until 2).flatMap { _ => + try { + Some(getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true")) + } catch { + // If we can't get a mirror URL, skip it. No retry. + case _: Exception => None + } + } + val sites = mirrors.distinct :+ "https://archive.apache.org/dist" + logInfo(s"Trying to download Spark $version from $sites") + for (site <- sites) { val filename = s"spark-$version-bin-hadoop2.7.tgz" - val url = s"$preferredMirror/spark/spark-$version/$filename" + val url = s"$site/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) - return + val downloaded = new File(sparkTestingDir, filename).getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + Seq("rm", downloaded).! + + // For a corrupted file, `tar` returns non-zero values. However, we also need to check + // the extracted file because `tar` returns 0 for empty file. + val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit") + if (exitCode == 0 && sparkSubmit.exists()) { + return + } else { + Seq("rm", "-rf", targetDir).! + } } catch { - case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) + case ex: Exception => + logWarning(s"Failed to download Spark $version from $url: ${ex.getMessage}") } } fail(s"Unable to download Spark $version") } - - private def downloadSpark(version: String): Unit = { - tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) - - val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath - val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath - - Seq("mkdir", targetDir).! - - Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! - - Seq("rm", downloaded).! - } - private def genDataDir(name: String): String = { new File(tmpDataDir, name).getCanonicalPath } @@ -161,7 +169,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") if (!sparkHome.exists()) { - downloadSpark(version) + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } val args = Seq( @@ -195,7 +203,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. 
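// Not the suite's code: the rewritten tryDownloadSpark above boils down to "try each
// candidate site in order and keep the first download that extracts and verifies".
// A generic form of that first-success-wins loop; `fetchAndVerify` is a hypothetical
// stand-in for the getFileFromUrl + tar + bin/spark-submit check performed above.
import scala.util.Try

object FirstWorkingSiteSketch {
  def firstWorking[A](sites: Seq[String])(fetchAndVerify: String => A): Option[A] =
    sites.iterator
      .map(site => Try(fetchAndVerify(site)).toOption)  // a failure just moves on to the next site
      .collectFirst { case Some(result) => result }

  // e.g. firstWorking(mirrors :+ "https://archive.apache.org/dist")(downloadSparkFrom)
}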
- val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") + val testingVersions = Seq("2.1.3", "2.2.2", "2.3.1") protected var spark: SparkSession = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala index 285f35b0b0eac..fd5f47e428239 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -26,7 +26,7 @@ class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveS private val externalCatalog = { val catalog = spark.sharedState.externalCatalog - catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.reset() catalog } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ba9b944e4a055..688b619cd1bb5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias @@ -62,7 +62,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { spark.sql("create view vw1 as select 1 as id") val plan = spark.sql("select id from vw1").queryExecution.analyzed val aliases = plan.collect { - case x @ SubqueryAlias("vw1", _) => x + case x @ SubqueryAlias(AliasIdentifier("vw1", Some("default")), _) => x } assert(aliases.size == 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index f2d27671094d7..51a48a20daaa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -50,7 +50,8 @@ class HiveSchemaInferenceSuite FileStatusCache.resetForTesting() } - private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val externalCatalog = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] private val client = externalCatalog.client // Return a copy of the given schema with all field names converted to lower case. 
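// Not Spark's code: the many `externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog]`
// rewrites in this patch exist because the shared external catalog is now handed out behind a
// listener-firing wrapper, so call sites that need the concrete Hive-backed catalog reach
// through the wrapper's delegate. A minimal decorator of that shape, with made-up names:
object UnwrappedDelegateSketch {
  trait Catalog { def tableExists(db: String, table: String): Boolean }

  class HiveBackedCatalog extends Catalog {
    override def tableExists(db: String, table: String): Boolean = true  // stand-in
    def runSqlHive(sql: String): Unit = ()         // only available on the concrete type
  }

  class CatalogWithListener(val unwrapped: Catalog, notify: String => Unit) extends Catalog {
    override def tableExists(db: String, table: String): Boolean = {
      notify(s"tableExists $db.$table")            // fire the event, then delegate
      unwrapped.tableExists(db, table)
    }
  }

  val catalog = new CatalogWithListener(new HiveBackedCatalog, _ => ())
  catalog.unwrapped.asInstanceOf[HiveBackedCatalog].runSqlHive("SHOW TABLES")
}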
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index ecc09cdcdbeaf..a3579862c9e59 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -44,7 +44,8 @@ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { val conf = sparkSession.sparkContext.hadoopConfiguration val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) sparkSession.cloneSession() - sparkSession.sharedState.externalCatalog.client.newSession() + sparkSession.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.newSession() val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) assert(oldValue == newValue, "cloneSession and then newSession should not affect the Derby directory") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 079fe45860544..aa5b531992613 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -354,7 +354,7 @@ object SetMetastoreURLTest extends Logging { // HiveExternalCatalog is used when Hive support is enabled. val actualMetastoreURL = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") @@ -780,7 +780,8 @@ object SPARK_18360 { val defaultDbLocation = spark.catalog.getDatabase("default").locationUri assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val hiveClient = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client try { val tableMeta = CatalogTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 95f192f0e40e2..1e525c46a9cfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -32,13 +33,22 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto import spark.implicits._ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) } + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF 
EXISTS metadata_only") + } finally { + super.afterAll() + } + } + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the number of matching partitions @@ -50,7 +60,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto } test("SPARK-23877: filter on projected expression") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index fad81c7e9474e..34ca790299859 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -288,8 +288,24 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } } + test("SPARK-24911: keep quotes for nested fields") { + withTable("t1") { + val createTable = "CREATE TABLE `t1`(`a` STRUCT<`b`: STRING>)" + sql(createTable) + val shownDDL = sql(s"SHOW CREATE TABLE t1") + .head() + .getString(0) + .split("\n") + .head + assert(shownDDL == createTable) + + checkCreateTable("t1") + } + } + private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.runSqlHive(ddl) + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) } private def checkCreateTable(table: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 61cec82984795..d8ffb29a59317 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -25,13 +25,14 @@ import scala.util.matching.Regex import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.command.{CommandUtils, DDLUtils} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.HiveExternalCatalog._ @@ -148,6 +149,26 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("SPARK-24626 parallel file listing in Stats computation") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "2", + SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION.key -> "True") { + val checkSizeTable = "checkSizeTable" + withTable(checkSizeTable) { + sql(s"CREATE TABLE $checkSizeTable (key STRING, value STRING) PARTITIONED BY (ds STRING)") + 
sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-01') SELECT * FROM src") + sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-02') SELECT * FROM src") + sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-03') SELECT * FROM src") + val tableMeta = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(checkSizeTable)) + HiveCatalogMetrics.reset() + assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 0) + val size = CommandUtils.calculateTotalSize(spark, tableMeta) + assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 1) + assert(size === BigInt(17436)) + } + } + } + test("analyze non hive compatible datasource tables") { val table = "parquet_tab" withTable(table) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 19765695fbcb4..2a4efd0cce6e0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -72,6 +72,20 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil, """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") + filterTest("SPARK-24879 null literals should be ignored for IN constructs", + (a("intcol", IntegerType) in (Literal(1), Literal(null))) :: Nil, + "(intcol = 1)") + + // Applying the predicate `x IN (NULL)` should return an empty set, but since this optimization + // will be applied by Catalyst, this filter converter does not need to account for this. + filterTest("SPARK-24879 IN predicates with only NULLs will not cause a NPE", + (a("intcol", IntegerType) in Literal(null)) :: Nil, + "") + + filterTest("typecast null literals should not be pushed down in simple predicates", + (a("intcol", IntegerType) === Literal(null, IntegerType)) :: Nil, + "") + private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index f991352b207d4..fa9f753795f65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType} // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) extends HiveVersionSuite(version) with BeforeAndAfterAll { - import CatalystSqlParser._ private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname @@ -46,8 +46,7 @@ class HiveClientSuite(version: String) val hadoopConf = new Configuration() hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) val client = buildClient(hadoopConf) - client - .runSqlHive("CREATE TABLE test (value INT) 
PARTITIONED BY (ds INT, h INT, chunk STRING)") + client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") val partitions = for { @@ -66,6 +65,15 @@ class HiveClientSuite(version: String) client } + private def attr(name: String): Attribute = { + client.getTable("default", "test").partitionSchema.fields + .find(field => field.name.equals(name)) match { + case Some(field) => AttributeReference(field.name, field.dataType)() + case None => + fail(s"Illegal name of partition attribute: $name") + } + } + override def beforeAll() { super.beforeAll() client = init(true) @@ -74,7 +82,7 @@ class HiveClientSuite(version: String) test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { val client = init(false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq(parseExpression("ds=20170101"))) + Seq(attr("ds") === 20170101)) assert(filteredPartitions.size == testPartitionCount) } @@ -82,7 +90,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds<=>20170101") { // Should return all partitions where <=> is not supported testMetastorePartitionFiltering( - "ds<=>20170101", + attr("ds") <=> 20170101, 20170101 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -90,7 +98,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101") { testMetastorePartitionFiltering( - "ds=20170101", + attr("ds") === 20170101, 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -100,7 +108,7 @@ class HiveClientSuite(version: String) // Should return all partitions where h=0 because getPartitionsByFilter does not support // comparisons to non-literal values testMetastorePartitionFiltering( - "ds=(20170101 + 1) and h=0", + attr("ds") === (Literal(20170101) + 1) && attr("h") === 0, 20170101 to 20170103, 0 to 0, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -108,15 +116,31 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk='aa'") { testMetastorePartitionFiltering( - "chunk='aa'", + attr("chunk") === "aa", 20170101 to 20170103, 0 to 23, "aa" :: Nil) } + test("getPartitionsByFilter: cast(chunk as int)=1 (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(IntegerType) === 1, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(chunk as boolean)=true (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(BooleanType) === true, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( - "20170101=ds", + Literal(20170101) === attr("ds"), 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -124,7 +148,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 and h=10") { testMetastorePartitionFiltering( - "ds=20170101 and h=10", + attr("ds") === 20170101 && attr("h") === 10, + 20170101 to 20170101, + 10 to 10, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, 10 to 10, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -132,7 +164,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 or ds=20170102") { testMetastorePartitionFiltering( - 
"ds=20170101 or ds=20170102", + attr("ds") === 20170101 || attr("ds") === 20170102, 20170101 to 20170102, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -140,7 +172,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -148,7 +188,19 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil, { + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) + }) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)") + { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { @@ -159,7 +211,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil) @@ -167,7 +219,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { @@ -179,26 +231,24 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil) } test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) // Day 2 should include all hours because we can't build a filter for h<(7+1) val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil) } test("getPartitionsByFilter: " + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) - testMetastorePartitionFiltering( - "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", + testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") && + 
((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)), day1 :: day2 :: Nil) } @@ -207,41 +257,41 @@ class HiveClientSuite(version: String) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String]): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String], transform: Expression => Expression): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, - identity) + transform) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { - testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], transform: Expression => Expression): Unit = { val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), Seq( - transform(parseExpression(filterString)) + transform(filterExpr) )) val expectedPartitionCount = expectedPartitionCubes.map { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6176273c88db1..dc96ec416afd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -134,8 +134,8 @@ class VersionsSuite extends SparkFunSuite with Logging { client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) - assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .version.fullVersion.startsWith(version)) + assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ae675149df5e2..c65bf7c14c7a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1005,6 +1005,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te ) ) } + + test("SPARK-24957: average with decimal followed by aggregation returning wrong result") { + val df = Seq(("a", BigDecimal("12.0")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("11.9999999988")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("11.9999999988")), + ("a", BigDecimal("11.9999999988"))).toDF("text", "number") + val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res")) + val agg2 = 
agg1.groupBy($"text").agg(sum($"avg_res")) + checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142860000"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c85db78c732de..6708a50a961fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.net.URI +import java.util.Date import scala.language.existentials @@ -33,11 +34,13 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -57,7 +60,8 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA protected override def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean): CatalogTable = { + isDataSource: Boolean, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable = { val storage = if (isDataSource) { val serde = HiveSerDe.sourceToSerDe("parquet") @@ -81,17 +85,17 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA val metadata = new MetadataBuilder() .putString("key", "value") .build() + val schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") CatalogTable( identifier = name, tableType = CatalogTableType.EXTERNAL, storage = storage, - schema = new StructType() - .add("col1", "int", nullable = true, metadata = metadata) - .add("col2", "string") - .add("a", "int") - .add("b", "int"), + schema = schema.copy( + fields = schema.fields ++ partitionCols.map(StructField(_, IntegerType))), provider = if (isDataSource) Some("parquet") else Some("hive"), - partitionColumnNames = Seq("a", "b"), + partitionColumnNames = partitionCols, createTime = 0L, createVersion = org.apache.spark.SPARK_VERSION, tracksPartitionsInCatalog = true) @@ -780,7 +784,7 @@ class HiveDDLSuite val part1 = Map("a" -> "1", "b" -> "5") val part2 = Map("a" -> "2", "b" -> "6") val root = new Path(catalog.getTableMetadata(tableIdent).location) - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file @@ -1354,7 +1358,8 @@ class HiveDDLSuite val indexName = tabName + "_index" withTable(tabName) { // Spark SQL does not support creating index. Thus, we have to use Hive client. 
- val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client sql(s"CREATE TABLE $tabName(a int)") try { @@ -1392,7 +1397,8 @@ class HiveDDLSuite val tabName = "tab1" withTable(tabName) { // Spark SQL does not support creating skewed table. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client client.runSqlHive( s""" |CREATE Table $tabName(col1 int, col2 int) @@ -2144,6 +2150,86 @@ class HiveDDLSuite } } + private def getReader(path: String): org.apache.orc.Reader = { + val conf = spark.sessionState.newHadoopConf() + val files = org.apache.spark.sql.execution.datasources.orc.OrcUtils.listOrcFiles(path, conf) + assert(files.length == 1) + val file = files.head + val fs = file.getFileSystem(conf) + val readerOptions = org.apache.orc.OrcFile.readerOptions(conf).filesystem(fs) + org.apache.orc.OrcFile.createReader(file, readerOptions) + } + + test("SPARK-23355 convertMetastoreOrc should not ignore table properties - STORED AS") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(ORC_IMPLEMENTATION.key -> orcImpl, CONVERT_METASTORE_ORC.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS ORC + |TBLPROPERTIES ( + | orc.compress 'ZLIB', + | orc.compress.size '1001', + | orc.row.index.stride '2002', + | hive.exec.orc.default.block.size '3003', + | hive.exec.orc.compression.strategy 'COMPRESSION') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("orc")) + val properties = table.properties + assert(properties.get("orc.compress") == Some("ZLIB")) + assert(properties.get("orc.compress.size") == Some("1001")) + assert(properties.get("orc.row.index.stride") == Some("2002")) + assert(properties.get("hive.exec.orc.default.block.size") == Some("3003")) + assert(properties.get("hive.exec.orc.compression.strategy") == Some("COMPRESSION")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + val reader = getReader(maybeFile.head.getCanonicalPath) + assert(reader.getCompressionKind.name === "ZLIB") + assert(reader.getCompressionSize == 1001) + assert(reader.getRowIndexStride == 2002) + } + } + } + } + } + + test("SPARK-23355 convertMetastoreParquet should not ignore table properties - STORED AS") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS PARQUET + |TBLPROPERTIES ( + | parquet.compression 'GZIP' + |) + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("parquet")) + val properties = table.properties + assert(properties.get("parquet.compression") == Some("GZIP")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + assertCompression(maybeFile, "parquet", 
"GZIP") + } + } + } + } + test("load command for non local invalid path validation") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING)") @@ -2165,4 +2251,34 @@ class HiveDDLSuite checkAnswer(spark.table("t4"), Row(0, 0)) } } + + test("SPARK-24812: desc formatted table for last access verification") { + withTable("t1") { + sql( + "CREATE TABLE IF NOT EXISTS t1 (c1_int INT, c2_string STRING, c3_float FLOAT)") + val desc = sql("DESC FORMATTED t1").filter($"col_name".startsWith("Last Access")) + .select("data_type") + // check if the last access time doesnt have the default date of year + // 1970 as its a wrong access time + assert(!(desc.first.toString.contains("1970"))) + } + } + + test("SPARK-24681 checks if nested column names do not include ',', ':', and ';'") { + val expectedMsg = "Cannot create a table having a nested column whose name contains invalid " + + "characters (',', ':', ';') in Hive metastore." + + Seq("nested,column", "nested:column", "nested;column").foreach { nestedColumnName => + withTable("t") { + val e = intercept[AnalysisException] { + spark.range(1) + .select(struct(lit(0).as(nestedColumnName)).as("toplevel")) + .write + .format("hive") + .saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 5d56f89c2271c..c349a327694bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -171,20 +171,15 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } - test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { - val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - df.explain(true) + test("SPARK-23034 show relation names in Hive table scan nodes") { + val tableName = "tab" + withTable(tableName) { + sql(s"CREATE TABLE $tableName(c1 int) USING hive") + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + spark.table(tableName).explain(extended = false) + } + assert(output.toString.contains(s"Scan hive default.$tableName")) } - assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( - s"""== Parsed Logical Plan == - |GlobalLimit 1 - |+- LocalLimit 1 - | +- AnalysisBarrier - | +- Aggregate [a#0], [a#0, count(1) AS count#0L] - | +- Project [_1#0 AS a#0, _2#0 AS b#0] - | +- LocalRelation [_1#0, _2#0] - |""".stripMargin)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2ea51791d0f79..b9c32e789a410 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -84,7 +84,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd } // Testing the Broadcast based join for cartesian join (cross join) - // We assume that the Broadcast Join Threshold will works since the src is a small table + // We assume that the Broadcast Join Threshold will work since the src is a small table private val spark_10484_1 = """ | SELECT a.key, b.key | FROM src a 
LEFT JOIN src b WHERE a.key > b.key + 300 @@ -1177,13 +1177,18 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd assert(spark.table("with_parts").filter($"p" === 2).collect().head == Row(1, 2)) } - val originalValue = spark.sparkContext.hadoopConfiguration.get(modeConfKey, "nonstrict") + // Turn off style check since the following test is to modify hadoop configuration on purpose. + // scalastyle:off hadoopconfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration + + val originalValue = hadoopConf.get(modeConfKey, "nonstrict") try { - spark.sparkContext.hadoopConfiguration.set(modeConfKey, "nonstrict") + hadoopConf.set(modeConfKey, "nonstrict") sql("INSERT OVERWRITE TABLE with_parts partition(p) select 3, 4") assert(spark.table("with_parts").filter($"p" === 4).collect().head == Row(3, 4)) } finally { - spark.sparkContext.hadoopConfiguration.set(modeConfKey, originalValue) + hadoopConf.set(modeConfKey, originalValue) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index cc592cf6ca629..16541295eb453 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -22,21 +22,29 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} +import org.apache.spark.sql.internal.SQLConf /** * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 // for details. 
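// Not part of the patch: the modeConfKey handling in HiveQuerySuite above and the PruningSuite
// beforeAll/afterAll changes around this point both follow the same "save, override, restore in
// a finally block" discipline. A generic loan-style helper expressing that idea (names are
// illustrative only):
object WithSettingSketch {
  def withSetting[A, B](read: => A, write: A => Unit, value: A)(body: => B): B = {
    val saved = read                    // remember the original value
    write(value)
    try body finally write(saved)       // restore even if body throws
  }

  // e.g. withSetting(hadoopConf.get(key), (v: String) => hadoopConf.set(key, v), "nonstrict") {
  //   sql("INSERT OVERWRITE TABLE with_parts partition(p) select 3, 4")
  // }
}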
TestHive.reset() } + override def afterAll() { + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) + super.afterAll() + } // Column pruning tests diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 73f83d593bbfb..20c4c36c05091 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1912,11 +1912,60 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql("LOAD DATA LOCAL INPATH '/non-exist-folder/*part*' INTO TABLE load_t") }.getMessage assert(m.contains("LOAD DATA input path does not exist")) + } + } + } - val m2 = intercept[AnalysisException] { - sql(s"LOAD DATA LOCAL INPATH '$path*/*part*' INTO TABLE load_t") + test("Support wildcard character in folderlevel for LOAD DATA LOCAL INPATH") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t_folder_wildcard") { + sql("CREATE TABLE load_t (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '${ + path.substring(0, path.length - 1) + .concat("*") + }/' INTO TABLE load_t") + checkAnswer(sql("SELECT * FROM load_t"), Seq(Row("1"), Row("2"), Row("3"))) + val m = intercept[AnalysisException] { + sql(s"LOAD DATA LOCAL INPATH '${ + path.substring(0, path.length - 1).concat("_invalid_dir") concat ("*") + }/' INTO TABLE load_t") }.getMessage - assert(m2.contains("LOAD DATA input path allows only filename wildcard")) + assert(m.contains("LOAD DATA input path does not exist")) + } + } + } + + test("SPARK-17796 Support wildcard '?'char in middle as part of local file path") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t1") { + sql("CREATE TABLE load_t1 (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '$path/part-r-0000?' 
INTO TABLE load_t1") + checkAnswer(sql("SELECT * FROM load_t1"), Seq(Row("1"), Row("2"), Row("3"))) + } + } + } + + test("SPARK-17796 Support wildcard '?'char in start as part of local file path") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t2") { + sql("CREATE TABLE load_t2 (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '$path/?art-r-00001' INTO TABLE load_t2") + checkAnswer(sql("SELECT * FROM load_t2"), Seq(Row("1"))) } } } @@ -1967,6 +2016,22 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("column resolution scenarios with hive table") { + val currentDb = spark.catalog.currentDatabase + withTempDatabase { db1 => + try { + spark.catalog.setCurrentDatabase(db1) + spark.sql("CREATE TABLE t1(i1 int) STORED AS parquet") + spark.sql("INSERT INTO t1 VALUES(1)") + checkAnswer(spark.sql(s"SELECT $db1.t1.i1 FROM t1"), Row(1)) + checkAnswer(spark.sql(s"SELECT $db1.t1.i1 FROM $db1.t1"), Row(1)) + checkAnswer(spark.sql(s"SELECT $db1.t1.* FROM $db1.t1"), Row(1)) + } finally { + spark.catalog.setCurrentDatabase(currentDb) + } + } + } + test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { withTable("bar") { withTempView("foo") { @@ -2053,7 +2118,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val deleteOnExitField = classOf[FileSystem].getDeclaredField("deleteOnExit") deleteOnExitField.setAccessible(true) - val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val setOfPath = deleteOnExitField.get(fs).asInstanceOf[Set[Path]] val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() @@ -2099,7 +2164,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("orc", "parquet").foreach { format => test(s"SPARK-18355 Read data from a hive table with a new column - $format") { - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client Seq("true", "false").foreach { value => withSQLConf( @@ -2156,4 +2222,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-24085 scalar subquery in partitioning expression") { + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(format) { + withTempPath { tempDir => + sql( + s""" + |CREATE TABLE ${format} (id_value string) + |PARTITIONED BY (id_type string) + |LOCATION '${tempDir.toURI}' + |STORED AS ${format} + """.stripMargin) + sql(s"insert into $format values ('1','a')") + sql(s"insert into $format values ('2','a')") + sql(s"insert into $format values ('3','b')") + sql(s"insert into $format values ('4','b')") + checkAnswer( + sql(s"SELECT * FROM $format WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } + } + } + } + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 5318b4650b01f..5f73b7170c612 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -136,6 +136,25 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { } assert(e.getMessage.contains("Subprocess exited with status")) } + + test("SPARK-24339 verify the result after pruning the unused columns") { + val rowsDf = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformationExec( + input = Seq(rowsDf.col("name").expr), + script = "cat", + output = Seq(AttributeReference("name", StringType)()), + child = child, + ioschema = serdeIOSchema + ), + rowsDf.select("name").collect()) + } } private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index d556a030e2186..d84f9a3828207 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -133,4 +135,42 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { Utils.deleteRecursively(location) } } + + test("SPARK-24204 error handling for unsupported data types") { + withTempDir { dir => + val orcDir = new File(dir, "orc").getCanonicalPath + + // write path + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[AnalysisException] { + sql("select null").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + 
spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index d3fff37c3424d..d50bf0b8fd603 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -30,7 +30,7 @@ trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 53397991e59dc..b9ec940ac4925 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -666,7 +666,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes assert(expectedResult.isRight, s"Was not expecting error with $path: " + e) assert( e.getMessage.contains(expectedResult.right.get), - s"Did not find expected error message wiht $path") + s"Did not find expected error message with $path") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index dce5bb7ddba66..6858bbc441721 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -124,7 +124,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("SPARK-8604: Parquet data source should write summary file while doing appending") { withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { withTempPath { dir => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index e23edfa506517..4a4d2c5d9d8c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -940,6 +940,11 @@ abstract class DStream[T: ClassTag] ( object DStream { + private val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r + private val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r + private val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r + private val SCALA_CLASS_REGEX = """^scala""".r + // `toPairDStreamFunctions` was in SparkContext before 1.3 and users had to // `import StreamingContext._` to enable it. Now we move it here to make the compiler find // it automatically. 
However, we still keep the old function in StreamingContext for backward @@ -953,11 +958,6 @@ object DStream { /** Get the creation site of a DStream from the stack trace of when the DStream is created. */ private[streaming] def getCreationSite(): CallSite = { - val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r - val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r - val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r - val SCALA_CLASS_REGEX = """^scala""".r - /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala index 7b29b40668def..8717555dea491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, Utils} /** - * Class that manages executor allocated to a StreamingContext, and dynamically request or kill + * Class that manages executors allocated to a StreamingContext, and dynamically requests or kills * executors based on the statistics of the streaming computation. This is different from the core * dynamic allocation policy; the core policy relies on executors being idle for a while, but the * micro-batch model of streaming prevents any particular executors from being idle for a long @@ -43,6 +43,10 @@ import org.apache.spark.util.{Clock, Utils} * * This features should ideally be used in conjunction with backpressure, as backpressure ensures * system stability, while executors are being readjusted. + * + * Note that an initial set of executors (spark.executor.instances) was allocated when the + * SparkContext was created. This class scales executors up/down after the StreamingContext + * has started. 
*/ private[streaming] class ExecutorAllocationManager( client: ExecutorAllocationClient, @@ -202,12 +206,7 @@ private[streaming] object ExecutorAllocationManager extends Logging { val MAX_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.maxExecutors" def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - val numExecutor = conf.getInt("spark.executor.instances", 0) val streamingDynamicAllocationEnabled = conf.getBoolean(ENABLED_KEY, false) - if (numExecutor != 0 && streamingDynamicAllocationEnabled) { - throw new IllegalArgumentException( - "Dynamic Allocation for streaming cannot be enabled while spark.executor.instances is set.") - } if (Utils.isDynamicAllocationEnabled(conf) && streamingDynamicAllocationEnabled) { throw new IllegalArgumentException( """ @@ -217,7 +216,7 @@ private[streaming] object ExecutorAllocationManager extends Logging { """.stripMargin) } val testing = conf.getBoolean("spark.streaming.dynamicAllocation.testing", false) - numExecutor == 0 && streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing) + streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing) } def createIfEnabled( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index dacff69d55dd2..cf4324578ea87 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -112,10 +112,11 @@ private[streaming] class ReceivedBlockTracker( def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { val streamIdToBlocks = streamIds.map { streamId => - (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + (streamId, getReceivedBlockQueue(streamId).clone()) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + streamIds.foreach(getReceivedBlockQueue(_).clear()) timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 6748dd4ec48e3..884d21d0afdd3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -47,6 +47,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -54,7 +55,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { isFirstRow: Boolean, jobIdWithData: SparkJobIdWithUIData): Seq[Node] = { if (jobIdWithData.jobData.isDefined) { - generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, + generateNormalJobRow(request, outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.jobData.get) } else { generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, @@ -89,6 +90,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * one cell, we use "rowspan" for 
the first row of an output op. */ private def generateNormalJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -106,7 +108,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { dropWhile(_.failureReason == None).take(1). // get the first info that contains failure flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") - val detailUrl = s"${SparkUIUtils.prependBaseUri(parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + val detailUrl = s"${SparkUIUtils.prependBaseUri( + request, parent.basePath)}/jobs/job/?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". @@ -196,6 +199,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { val formattedOutputOpDuration = @@ -212,6 +216,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } else { val firstRow = generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -221,6 +226,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val tailRows = sparkJobs.tail.map { sparkJob => generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -278,7 +284,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { /** * Generate the job table for the batch. */ - private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { + private def generateJobTable( + request: HttpServletRequest, + batchUIData: BatchUIData): Seq[Node] = { val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId @@ -301,7 +309,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { outputOpWithJobs.map { case (outputOpData, sparkJobs) => - generateOutputOpIdRow(outputOpData, sparkJobs) + generateOutputOpIdRow(request, outputOpData, sparkJobs) } } @@ -364,9 +372,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
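The BatchPage changes above, together with the render hunk that follows, thread the incoming HttpServletRequest down to SparkUIUtils.prependBaseUri and headerSparkPage, and the job detail link now points at /jobs/job/?id=. The reason for passing the request through is that, when the UI is served behind a reverse proxy, the URL prefix can only be determined per request, so it cannot be baked into the page object. Below is a minimal, self-contained sketch of that idea; FakeRequest, ProxyAwareLinks, and the X-Forwarded-Context header name are illustrative stand-ins for this sketch, not Spark's actual SparkUIUtils API.

// Sketch only: a stand-in request type and header name, assuming the proxy sets a prefix header.
object ProxyAwareLinks {
  final case class FakeRequest(headers: Map[String, String]) {
    def getHeader(name: String): String = headers.getOrElse(name, null)
  }

  private val ProxyHeader = "X-Forwarded-Context" // assumed header name for this sketch

  // Prefix a UI path with the per-request proxy base and the page's base path.
  def prependBaseUri(request: FakeRequest, basePath: String, resource: String = ""): String = {
    val proxyBase = Option(request.getHeader(ProxyHeader)).getOrElse("")
    proxyBase + basePath + resource
  }

  def main(args: Array[String]): Unit = {
    val direct = FakeRequest(Map.empty)
    val proxied = FakeRequest(Map(ProxyHeader -> "/gateway/spark"))
    println(prependBaseUri(direct, "", "/jobs/job/?id=7"))  // no proxy: /jobs/job/?id=7
    println(prependBaseUri(proxied, "", "/jobs/job/?id=7")) // proxied: /gateway/spark/jobs/job/?id=7
  }
}

The same reasoning applies to the StreamingPage and StreamingTab hunks further down, where generateLoadResources and headerSparkPage gain the same request parameter.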
      - val content = summary ++ generateJobTable(batchUIData) + val content = summary ++ generateJobTable(request, batchUIData) - SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) + SparkUIUtils.headerSparkPage( + request, s"Details of batch at $formattedBatchTime", content, parent) } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 3a176f64cdd60..4ce661bc1144e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -148,7 +148,7 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val resources = generateLoadResources() + val resources = generateLoadResources(request) val basicInfo = generateBasicInfo() val content = resources ++ basicInfo ++ @@ -156,17 +156,17 @@ private[ui] class StreamingPage(parent: StreamingTab) generateStatTable() ++ generateBatchListTables() } - SparkUIUtils.headerSparkPage("Streaming Statistics", content, parent, Some(5000)) + SparkUIUtils.headerSparkPage(request, "Streaming Statistics", content, parent, Some(5000)) } /** * Generate html that will load css/js files for StreamingPage */ - private def generateLoadResources(): Seq[Node] = { + private def generateLoadResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - + + + // scalastyle:on } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 9d1b82a6341b1..25e71258b9369 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -49,7 +49,7 @@ private[spark] class StreamingTab(val ssc: StreamingContext) def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).removeStaticHandler("/static/streaming") + getSparkUI(ssc).detachHandler("/static/streaming") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index ab7c8558321c8..bba071e80c0e4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -222,7 +222,7 @@ private[streaming] class FileBasedWriteAheadLog( pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) } currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000L) val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString) @@ -312,10 +312,10 @@ private[streaming] object FileBasedWriteAheadLog { handler: I => Iterator[O]): Iterator[O] = { val taskSupport = new ExecutionContextTaskSupport(executionContext) val groupSize = taskSupport.parallelismLevel.max(8) + implicit val ec = executionContext + source.grouped(groupSize).flatMap { group => - val parallelCollection = group.par - parallelCollection.tasksupport = 
taskSupport - parallelCollection.map(handler) + ThreadUtils.parmap(group)(handler) }.flatten } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 4fa236bd39663..fd7e00b1de25f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -26,10 +26,12 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration +import org.mockito.Matchers.any +import org.mockito.Mockito.{doThrow, reset, spy} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult @@ -115,6 +117,47 @@ class ReceivedBlockTrackerSuite tracker2.stop() } + test("block allocation to batch should not lose blocks from received queue") { + val tracker1 = spy(createTracker()) + tracker1.isWriteAheadLogEnabled should be (true) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + // Add blocks + val blockInfos = generateBlockInfos() + blockInfos.map(tracker1.addBlock) + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Try to allocate the blocks to a batch and verify that it fails + // The blocks should stay in the received queue when the WAL write fails + doThrow(new RuntimeException("Not able to write BatchAllocationEvent")) + .when(tracker1).writeToLog(any(classOf[BatchAllocationEvent])) + val errMsg = intercept[RuntimeException] { + tracker1.allocateBlocksToBatch(1) + } + assert(errMsg.getMessage === "Not able to write BatchAllocationEvent") + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + tracker1.getBlocksOfBatch(1) shouldEqual Map.empty + tracker1.getBlocksOfBatchAndStream(1, streamId) shouldEqual Seq.empty + + // Allocate the blocks to a batch and verify that all of them have been allocated + reset(tracker1) + tracker1.allocateBlocksToBatch(2) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker1.hasUnallocatedReceivedBlocks should be (false) + tracker1.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker1.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + + tracker1.stop() + + // Recover from the WAL to verify correctness + val tracker2 = createTracker(recoverFromWriteAheadLog = true) + tracker2.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker2.hasUnallocatedReceivedBlocks should be (false) + tracker2.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker2.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates @@ -312,7 +355,7 @@ class ReceivedBlockTrackerSuite recoverFromWriteAheadLog: Boolean = false, clock: Clock = new SystemClock): ReceivedBlockTracker = { val cpDirOption = if (setCheckpointDir) Some(checkpointDirectory.toString) else None - val tracker = new ReceivedBlockTracker( + var tracker = new ReceivedBlockTracker( conf, hadoopConf, 
Seq(streamId), clock, recoverFromWriteAheadLog, cpDirOption) allReceivedBlockTrackers += tracker tracker
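The ReceivedBlockTracker hunk earlier in this diff replaces dequeueAll with clone() and clears the received-block queues only after the BatchAllocationEvent has been written to the write-ahead log; the new test above ("block allocation to batch should not lose blocks from received queue") forces that WAL write to fail with a Mockito spy and checks that the blocks remain unallocated. Below is a minimal sketch of the allocate-then-clear pattern, assuming simplified stand-in types (Block, writeToLog, AllocateSketch) rather than the real ReceivedBlockTracker.

import scala.collection.mutable

// Sketch only: simplified stand-ins for the snapshot-write-clear ordering shown in the hunk above.
object AllocateSketch {
  final case class Block(id: Int)

  private val receivedQueue = mutable.Queue[Block]()
  private val allocated = mutable.Map[Long, Seq[Block]]()

  // Pretend WAL write; returns false (or throws) when the write fails.
  def writeToLog(batchTime: Long, blocks: Seq[Block]): Boolean = true

  def allocateBlocksToBatch(batchTime: Long): Unit = receivedQueue.synchronized {
    // Take a snapshot instead of dequeueing: the queue must stay intact if the
    // WAL write below fails, otherwise the blocks would be silently lost.
    val snapshot = receivedQueue.clone().toSeq
    if (writeToLog(batchTime, snapshot)) {
      receivedQueue.clear() // safe to drop only after a successful WAL write
      allocated(batchTime) = snapshot
    }
  }

  def main(args: Array[String]): Unit = {
    receivedQueue ++= Seq(Block(1), Block(2), Block(3))
    allocateBlocksToBatch(batchTime = 1000L)
    println(allocated)     // batch 1000 now owns the snapshot of the three blocks
    println(receivedQueue) // emptied only because the WAL write succeeded
  }
}

The point of the snapshot is ordering: nothing is removed from the receiving side until the allocation has been made durable, so a failed WAL write leaves the tracker exactly where it was.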